Commit 68c6c92
[WIP] Move standardization into approximators and make adapter stateless. (#486)
* Add standardization to continuous approximator and test
* Fix init bugs, adapt notebooks
* Add training flag to build_from_data
* Fix inference conditions check
* Fix tests
* Remove unnecessary init calls
* Add deprecation warning
* Refactor compute metrics and add standardization to model comp
* Fix standardization in cont approx
* Fix sample keys -> condition keys
* amazing keras fix
* moving_mean and moving_std still not loading [WIP]
* remove hacky approximator serialization test
* fix building of models in tests
* Fix standardization
* Add standardization to model comp and let it use inheritance
* make assert_models/layers_equal more thorough
* [no ci] use map_shape_structure to convert shapes to arrays
This automatically takes care of nested structures.
* Extend Standardization to support nested inputs (#501)
* extend Standardization to nested inputs
By using `keras.tree.flatten` and `keras.tree.pack_sequence_as`, we can
support arbitrary nested structures. A `flatten_shape` function is
introduced, analogous to `map_shape_structure`, for use in the build
function.
* keep tree utils in submodule
* Streamline call
* Fix typehint
---------
Co-authored-by: stefanradev93 <[email protected]>
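The flatten-and-repack approach described above can be sketched as follows. This is a minimal, hypothetical illustration using NumPy: the real implementation relies on `keras.tree.flatten` and `keras.tree.pack_sequence_as`, for which the `flatten`/`pack_sequence_as` helpers below are simplified stand-ins, and the per-leaf standardization is only a toy transform.

```python
# Sketch: standardize a nested structure by flattening it to a list of
# leaves, transforming each leaf, and re-packing into the original
# structure. keras.tree.flatten / keras.tree.pack_sequence_as provide
# this for arbitrary Keras structures; these helpers are simplified
# stand-ins for illustration only.
import numpy as np

def flatten(structure):
    # Depth-first flatten of dicts (sorted keys), lists, and tuples.
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]

def pack_sequence_as(template, flat):
    # Rebuild the nested structure of `template` from the flat leaf list.
    flat = list(flat)
    def _pack(t):
        if isinstance(t, dict):
            return {key: _pack(t[key]) for key in sorted(t)}
        if isinstance(t, (list, tuple)):
            return type(t)(_pack(item) for item in t)
        return flat.pop(0)
    return _pack(template)

nested = {"summary": np.array([1.0, 3.0]), "inference": np.array([2.0, 6.0])}
leaves = flatten(nested)
standardized = [(x - x.mean()) / x.std() for x in leaves]
result = pack_sequence_as(nested, standardized)
```

Each leaf is standardized independently, and the caller gets back the same nested structure it passed in.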
* Update moments before transform and update test
* Update notebooks
* Refactor and simplify due to standardize
* Add comment for fetching the dict's first item, deprecate logits arg and fix typehint
* add missing import in test
* Refactor preparation of data for networks and new point_appr.log_prob
* ContinuousApproximator._prepare_data unifies all preparation in
sample, log_prob and estimate for both ContinuousApproximator and
PointApproximator
* PointApproximator now overrides log_prob
* Add class attributes to inform proper standardization
* Implement stable moving mean and std
* Adapt and fix tests
* minor adaptations to moving average (update time, init)
We should put the update before the standardization, to use the maximum
amount of information available. We can then also initialize the moving
M^2 with zero, as it will be filled immediately.
The special case of M^2 = 0 is not problematic: zero variance implies
that all entries are equal, so we can set the standardized values to
zero (see my comment).
I added another test case to cover this, and extended the existing test
with a check of the standard deviation.
* increase tolerance of allclose tests
* [no ci] set trainable to False explicitly in ModelComparisonApproximator
* point estimate of covariance compatible with standardization
* properly set values to zero if std is zero
The cases producing inf and -inf were previously unhandled
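Guarding the zero-std case can be sketched as below. This is an illustrative NumPy sketch under the assumption that standardization is elementwise; `safe_standardize` is a hypothetical name, not the project's API.

```python
# Sketch: elementwise standardization that maps entries with zero std
# to exactly zero instead of producing inf, -inf, or nan. Note that
# np.where evaluates both branches, so the divisor itself must also be
# guarded to avoid a division-by-zero warning.
import numpy as np

def safe_standardize(x, mean, std):
    x = np.asarray(x, dtype=float)
    safe_std = np.where(std > 0, std, 1.0)       # guarded divisor
    return np.where(std > 0, (x - mean) / safe_std, 0.0)
```

Where std is zero, every entry equals the mean, so zero is the natural standardized value.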
* fix sample post-processing in point approximator
* activate tests for multivariate normal score
* [no ci] undo prev commit: MVN test still not stable, was hidden by std of 0
* specify explicit build functions for approximators
* set std for untrained standardization layer to one
An untrained layer thereby does not modify the input.
* [no ci] reformulate zero std case
* approximator builds: add guards against building networks twice
* [no ci] add comparison with loaded approx to workflow test
* Cleanup and address building standardization layers when None specified
* Cleanup and address building standardization layers when None specified 2
* Add default case for std transform and add transformation to doc.
* adapt handling of the special case M^2=0
* [no ci] minor fix in concatenate_valid_shapes
* [no ci] extend test suite for approximators
* fixes for standardize=None case
* skip unstable MVN score case
* Better transformation types
* Add test for both_sides_scale inverse standardization
* Add test for left_side_scale inverse standardization
* Remove flaky test failing due to sampling error
* Fix input dtypes in inverse standardization transformation_type tests
* Use concatenate_valid in _sample
* Replace PositiveDefinite link with CholeskyFactor
This finally makes the MVN score sampling test stable on the jax backend,
where the keras.ops.cholesky operation is numerically unstable.
To resolve the issue, the score's sample method no longer calls
keras.ops.cholesky; instead, the estimation head returns the Cholesky
factor directly rather than the covariance matrix (as it did before).
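The sampling shortcut described above can be sketched in NumPy. This is a hedged illustration of the general technique, not the score's actual sample method: when the head already outputs a Cholesky factor L, samples with covariance L @ L.T are obtained without ever decomposing a covariance matrix at sampling time.

```python
# Sketch: sample from a multivariate normal given a Cholesky factor L
# returned directly by the estimation head, so no (numerically fragile)
# cholesky decomposition is needed at sampling time.
import numpy as np

def sample_mvn(mean, chol, num_samples, rng=None):
    rng = np.random.default_rng(rng)
    dim = mean.shape[-1]
    z = rng.standard_normal((num_samples, dim))  # z ~ N(0, I)
    # x = mean + L z has covariance L @ L.T
    return mean + z @ chol.T

mean = np.array([1.0, -1.0])
chol = np.array([[2.0, 0.0], [0.5, 1.0]])  # lower-triangular factor
samples = sample_mvn(mean, chol, num_samples=100_000, rng=0)
```

The head's output only needs to be constrained to a valid (lower-triangular, positive-diagonal) factor, which is exactly what a CholeskyFactor link provides.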
* Reintroduce test sampling with MVN score
* Address TODOs and adapt docstrings and workflow
* Adapt notebooks
* Fix in model comparison
* Update readme and add point estimation nb
---------
Co-authored-by: LarsKue <[email protected]>
Co-authored-by: Valentin Pratz <[email protected]>
Co-authored-by: Valentin Pratz <[email protected]>
Co-authored-by: han-ol <[email protected]>
Co-authored-by: Hans Olischläger <[email protected]>
1 parent: 735969c
File tree

41 files changed, +3560 −2681 lines changed

- bayesflow
  - adapters/transforms
  - approximators
  - links
  - networks/standardization
  - scores
  - utils
  - workflows
- examples
- tests
  - test_approximators
  - test_approximator_standardization
  - test_model_comparison_approximator
  - test_point_approximators
  - test_links
  - test_networks
  - test_two_moons
  - test_workflows
  - utils