Merged
71 commits
ceab303
Add standardization to continuous approximator and test
stefanradev93 May 22, 2025
d79b17a
Fix init bugs, adapt tnotebooks
stefanradev93 May 23, 2025
c777122
Add training flag to build_from_data
stefanradev93 May 23, 2025
7aeb9cb
Fix inference conditions check
stefanradev93 May 23, 2025
45ab9ea
Fix tests
stefanradev93 May 23, 2025
a83770a
Remove unnecessary init calls
stefanradev93 May 23, 2025
4df270a
Add deprecation warning
stefanradev93 May 24, 2025
8ea6782
Refactor compute metrics and add standardization to model comp
stefanradev93 May 25, 2025
b2a4f76
Fix standardization in cont approx
stefanradev93 May 26, 2025
deffc27
Fix sample keys -> condition keys
stefanradev93 May 26, 2025
43af4bd
amazing keras fix
LarsKue May 26, 2025
039fc8d
moving_mean and moving_std still not loading [WIP]
stefanradev93 May 26, 2025
02ded97
remove hacky approximator serialization test
LarsKue May 27, 2025
54d860e
fix building of models in tests
LarsKue May 27, 2025
2a86cc3
Fix standardization
stefanradev93 May 27, 2025
1df9269
Add standardizatrion to model comp and let it use inheritance
stefanradev93 May 27, 2025
49af469
make assert_models/layers_equal more thorough
LarsKue May 27, 2025
1fdde32
Merge remote-tracking branch 'origin/standardize_in_approx' into stan…
LarsKue May 27, 2025
0869e3f
[no ci] use map_shape_structure to convert shapes to arrays
vpratz May 31, 2025
1a845e3
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 1, 2025
bd2725d
Extend Standardization to support nested inputs (#501)
vpratz Jun 1, 2025
c5fb949
Update moments before transform and update test
stefanradev93 Jun 1, 2025
100d7c0
Update notebooks
stefanradev93 Jun 1, 2025
905bf05
Merge dev into branch
stefanradev93 Jun 1, 2025
38f2228
Refactor and simplify due to standardize
stefanradev93 Jun 1, 2025
0c24db2
Add comment for fetching the dict's first item, deprecate logits arg …
stefanradev93 Jun 2, 2025
5755135
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 2, 2025
4fa1bbb
add missing import in test
vpratz Jun 2, 2025
b2bfeea
Refactor preparation of data for networks and new point_appr.log_prob
han-ol Jun 3, 2025
5773d28
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 Jun 3, 2025
392d9f7
Add class attributes to inform proper standardization
han-ol Jun 4, 2025
2d5b2fb
Implement stable moving mean and std
stefanradev93 Jun 4, 2025
bde587c
Merge and add incremental moments
stefanradev93 Jun 4, 2025
1b2b5be
Adapt and fix tests
stefanradev93 Jun 4, 2025
d406a29
minor adaptations to moving average (update time, init)
vpratz Jun 5, 2025
a503bd9
increase tolerance of allclose tests
vpratz Jun 5, 2025
caf0491
[no ci] set trainable to False explicitly in ModelComparisonApproximator
vpratz Jun 5, 2025
dd24941
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 Jun 5, 2025
8268128
point estimate of covariance compatible with standardization
han-ol Jun 6, 2025
e32ae2e
properly set values to zero if std is zero
vpratz Jun 6, 2025
b7d6c0e
fix sample post-processing in point approximator
vpratz Jun 6, 2025
00d72ab
activate tests for multivariate normal score
vpratz Jun 6, 2025
c2ebd23
[no ci] undo prev commit: MVN test still not stable, was hidden by st…
vpratz Jun 6, 2025
cd45b85
specify explicit build functions for approximators
vpratz Jun 6, 2025
3f28f34
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 6, 2025
0952a29
set std for untrained standardization layer to one
vpratz Jun 6, 2025
5c529a2
[no ci] reformulate zero std case
vpratz Jun 6, 2025
399a1b4
approximator builds: add guards against building networks twice
vpratz Jun 6, 2025
dd0dc87
[no ci] add comparison with loaded approx to workflow test
vpratz Jun 6, 2025
d28df75
Cleanup and address building standardization layers when None specified
stefanradev93 Jun 6, 2025
40d2d1d
Cleanup and address building standardization layers when None specifi…
stefanradev93 Jun 6, 2025
c6d79ae
Add default case for std transform and add transformation to doc.
stefanradev93 Jun 6, 2025
df1761b
adapt handling of the special case M^2=0
vpratz Jun 7, 2025
3b93251
[no ci] minor fix in concatenate_valid_shapes
vpratz Jun 7, 2025
65cac46
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 7, 2025
1944186
[no ci] extend test suite for approximators
vpratz Jun 7, 2025
1ebf1cd
fixes for standardize=None case
vpratz Jun 7, 2025
71cd6b9
skip unstable MVN score case
vpratz Jun 7, 2025
3f0f9d1
Better transformation types
han-ol Jun 9, 2025
a3b59c3
Add test for both_sides_scale inverse standardization
han-ol Jun 9, 2025
183f608
Add test for left_side_scale inverse standardization
han-ol Jun 9, 2025
f0de38b
Remove flaky test failing due to sampling error
han-ol Jun 9, 2025
43ced5b
Fix input dtypes in inverse standardization transformation_type tests
han-ol Jun 9, 2025
c3e945e
Merge branch 'dev' into standardize_in_approx
han-ol Jun 9, 2025
82e28a7
Use concatenate_valid in _sample
han-ol Jun 10, 2025
ef97a6c
Replace PositiveDefinite link with CholeskyFactor
han-ol Jun 10, 2025
24c268b
Reintroduce test sampling with MVN score
han-ol Jun 10, 2025
e45f260
Address TODOs and adapt docstrings and workflow
stefanradev93 Jun 11, 2025
333c30f
Adapt notebooks
stefanradev93 Jun 11, 2025
48bb190
Fix in model comparison
stefanradev93 Jun 11, 2025
fd83567
Update readme and add point estimation nb
stefanradev93 Jun 12, 2025
9 changes: 9 additions & 0 deletions bayesflow/adapters/transforms/standardize.py
@@ -1,4 +1,5 @@
from collections.abc import Sequence
import warnings

import numpy as np

@@ -69,6 +70,14 @@ def __init__(
):
super().__init__()

if mean is None or std is None:
warnings.warn(
"Dynamic standardization is deprecated and will be removed in later versions."
"Instead, use the standardize argument of the approximator / workflow instance or provide "
"fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
DeprecationWarning,
)

Collaborator:

I think it would be nice to have a convenience function that calculates mean and std for a dataset, in the format that would be required here. We could also advertise it in the deprecation warning. What do you think?

Contributor Author:

Agreed, but such a function will not be very efficient when the entire data set is not (yet) in memory. I see its use mainly for OfflineDataset.

self.mean = mean
self.std = std

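The convenience function suggested in the review thread above could be sketched as follows: a streaming computation of per-feature mean and std over an iterable of batches (Chan-style parallel moment merging), so it stays efficient even when the data set is not fully in memory. The function name and the batch-iterable interface are hypothetical, not part of bayesflow.

```python
import numpy as np

def summarize_moments(batches):
    """Hypothetical helper: compute per-feature mean and std over an
    iterable of batches without materializing the full data set,
    using the parallel (Chan et al.) update of streaming moments."""
    count, mean, m2 = 0, None, None  # m2 = sum of squared deviations
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        n = batch.shape[0]
        b_mean = batch.mean(axis=0)
        b_m2 = ((batch - b_mean) ** 2).sum(axis=0)
        if mean is None:
            count, mean, m2 = n, b_mean, b_m2
        else:
            delta = b_mean - mean
            total = count + n
            # merge batch moments into the running moments
            mean = mean + delta * n / total
            m2 = m2 + b_m2 + delta**2 * count * n / total
            count = total
    return mean, np.sqrt(m2 / count)  # population std (ddof=0)
```

The result matches a single-pass `np.mean` / `np.std` over the concatenated data, which is the format the `Standardize` transform expects for its fixed `mean` and `std` arguments.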
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/transform.py
@@ -22,6 +22,7 @@ def __repr__(self):

@classmethod
def from_config(cls, config: dict, custom_objects=None):
# noinspection PyArgumentList
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self) -> dict:
14 changes: 9 additions & 5 deletions bayesflow/approximators/approximator.py
@@ -11,17 +11,17 @@


class Approximator(BackendApproximator):
def build(self, data_shapes: any) -> None:
mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
def build(self, data_shapes: dict[str, tuple[int]]) -> None:
Collaborator:

Our type hints currently do not reflect nested structures. Do we leave it like that for now or do we want to adapt them? What would be a good hint in that case, just dict[str, tuple[int] | dict[str, dict]], or something more involved?

Contributor Author (stefanradev93, Jun 2, 2025):

I would keep it simple for now as you suggested: dict[str, tuple[int] | dict[str, dict]].
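The nested shape structures discussed above can be illustrated with a minimal, keras-free sketch of what `map_shape_structure` does here: walk a nested dict whose leaves are shape tuples and replace each leaf with a zero array. The dict keys below are hypothetical example names, not the actual data keys.

```python
import numpy as np

def mock_from_shapes(shapes):
    # Recursively map a nested dict of shape tuples to zero arrays,
    # mimicking keras.tree.map_shape_structure(keras.ops.zeros, shapes).
    if isinstance(shapes, dict):
        return {k: mock_from_shapes(v) for k, v in shapes.items()}
    return np.zeros(shapes)  # leaf: a shape tuple

# hypothetical data_shapes of type dict[str, tuple[int] | dict[str, dict]]
data_shapes = {
    "summary_variables": (4, 10, 2),
    "inference_variables": {"theta": (4, 3)},
}
mock_data = mock_from_shapes(data_shapes)
```

Building on such mock data is what lets the approximator create all weights before training starts.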

mock_data = keras.tree.map_shape_structure(keras.ops.zeros, data_shapes)
self.build_from_data(mock_data)

@classmethod
def build_adapter(cls, **kwargs) -> Adapter:
# implemented by each respective architecture
raise NotImplementedError

def build_from_data(self, data: dict[str, any]) -> None:
self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
def build_from_data(self, adapted_data: dict[str, any]) -> None:
self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
Collaborator:

This call can error (with a very ugly/uninformative error message) in Jax if the inference and summary networks do not check themselves whether they are already built (see 01aadf1 for a fix where this was the problem). We could think about using a NotImplementedError here and trying to do a proper build in the subclasses, but I'm not sure if it is worth it right now.

Contributor Author:

The call was also present before. What is essential is the call to build on L139 of the Approximator, because keras wraps the build function internally. I think I like the idea of re-implementing build in the subclasses. @LarsKue What do you think?

Collaborator:

Yes, sorry for not making it more explicit, this problem is older and not related to the changes here, we can also move this to a separate issue if you like.

Contributor Author:

My biggest issue with this right now is that build_from_data runs the networks with a bunch of zeros, which is bad for initializing the running means and stds.

Contributor Author:

I would probably address it here. This turned out to be one of those monstrous PRs which start with a harmless feature and end up in a gargantuan refactor.

Contributor:

We need to replace the build method entirely if we want to reliably build on real data.

self.built = True

@classmethod
@@ -61,6 +61,9 @@
max_queue_size=max_queue_size,
)

def call(self, *args, **kwargs):
return self.compute_metrics(*args, **kwargs)

Codecov warning on bayesflow/approximators/approximator.py#L65: Added line #L65 was not covered by tests.

def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
"""
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -132,6 +135,7 @@
logging.info("Building on a test batch.")
mock_data = dataset[0]
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
self.build_from_data(mock_data)
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
self.build(mock_data_shapes)

return super().fit(dataset=dataset, **kwargs)
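One of the commits above adds guards against building networks twice, which relates to the Jax issue raised in the review thread. The idea can be sketched with a hypothetical class (not bayesflow code): once `built` is set, a second `build` call becomes a no-op instead of re-creating state and triggering opaque backend errors.

```python
class LazyNetwork:
    """Hypothetical sketch of the double-build guard: build() is a no-op
    once the network is built, so an approximator that triggers build
    more than once cannot corrupt or re-create existing state."""

    def __init__(self):
        self.built = False
        self.kernel_shape = None

    def build(self, input_shape):
        if self.built:
            return  # guard: already built, do nothing
        self.kernel_shape = (input_shape[-1], 8)  # placeholder "weights"
        self.built = True

net = LazyNetwork()
net.build((32, 4))
net.build((32, 16))  # ignored by the guard; state from the first build stays
```

The guard also matches the `self.built = True` bookkeeping visible in the `build_from_data` diff above.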