Skip to content

Commit 43af4bd

Browse files
committed
amazing keras fix
1 parent deffc27 commit 43af4bd

File tree

5 files changed

+65 -17 lines changed

5 files changed

+65
-17
lines changed

bayesflow/approximators/approximator.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111

1212

1313
class Approximator(BackendApproximator):
14-
def build(self, data_shapes: any) -> None:
15-
mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
14+
def build(self, data_shapes: dict[str, tuple[int]]) -> None:
15+
mock_data = {key: keras.ops.zeros(value) for key, value in data_shapes.items()}
1616
self.build_from_data(mock_data)
1717

1818
@classmethod
@@ -61,6 +61,9 @@ def build_dataset(
6161
max_queue_size=max_queue_size,
6262
)
6363

64+
def call(self, *args, **kwargs):
65+
return self.compute_metrics(*args, **kwargs)
66+
6467
def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
6568
"""
6669
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.

bayesflow/approximators/continuous_approximator.py

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -49,17 +49,23 @@ def __init__(
4949
self.inference_network = inference_network
5050
self.summary_network = summary_network
5151

52-
if standardize == "all":
53-
standardize = ["inference_variables", "summary_variables", "inference_conditions"]
54-
elif isinstance(standardize, str):
55-
standardize = [standardize]
56-
elif isinstance(standardize, Sequence):
57-
standardize = standardize
58-
else:
59-
standardize = []
52+
# if standardize == "all":
53+
# standardize = ["inference_variables", "summary_variables", "inference_conditions"]
54+
# elif isinstance(standardize, str):
55+
# standardize = [standardize]
56+
# elif isinstance(standardize, Sequence):
57+
# standardize = standardize
58+
# else:
59+
# standardize = []
6060

6161
self.standardize = standardize
62-
self.standardize_layers = {s: Standardization() for s in standardize}
62+
63+
if standardize == "all":
64+
# we have to lazily initialize these
65+
self.standardize_layers = None
66+
else:
67+
print("eager init")
68+
self.standardize_layers = {s: Standardization(trainable=False) for s in self.standardize}
6369

6470
@classmethod
6571
def build_adapter(
@@ -121,7 +127,16 @@ def compile(
121127
return super().compile(*args, **kwargs)
122128

123129
def build_from_data(self, adapted_data: dict[str, any]):
130+
if self.standardize == "all":
131+
self.standardize = list(adapted_data.keys())
132+
self.standardize = ["inference_variables", "summary_variables", "inference_conditions"]
133+
self.standardize = list(filter(lambda x: x in adapted_data, self.standardize))
134+
135+
if self.standardize_layers is None:
136+
self.standardize_layers = {s: Standardization(trainable=False) for s in self.standardize}
137+
124138
self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
139+
125140
self.built = True
126141

127142
def compile_from_config(self, config):
@@ -207,11 +222,11 @@ def _compute_summary_metrics(self, summary_variables: Tensor | None, stage: str)
207222
summary_outputs = summary_metrics.pop("outputs")
208223
return summary_metrics, summary_outputs
209224

210-
def _prepare_inference_variables(self, inference_variables, stage):
225+
def _prepare_inference_variables(self, inference_variables: Tensor, stage: str) -> Tensor:
211226
"""Helper function to convert inference variables to tensors and optionally standardize them."""
212-
inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
213227
if "inference_variables" in self.standardize:
214228
inference_variables = self.standardize_layers["inference_variables"](inference_variables, stage=stage)
229+
215230
return inference_variables
216231

217232
def _combine_conditions(

bayesflow/networks/standardization/standardization.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,12 @@
44

55
from bayesflow.types import Tensor, Shape
66
from bayesflow.utils.serialization import serialize, deserialize, serializable
7-
from bayesflow.utils import expand_left_as
7+
from bayesflow.utils import expand_left_as, layer_kwargs
88

99

1010
@serializable("bayesflow.networks")
1111
class Standardization(keras.Layer):
12-
def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6):
12+
def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6, **kwargs):
1313
"""
1414
Initializes a Standardization layer that will keep track of the running mean and
1515
running standard deviation across a batch of tensors.
@@ -23,14 +23,17 @@ def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6):
2323
epsilon: float, optional
2424
Stability parameter to avoid division by zero.
2525
"""
26-
super().__init__()
26+
super().__init__(**layer_kwargs(kwargs))
2727

2828
self.momentum = momentum
2929
self.epsilon = epsilon
3030
self.moving_mean = None
3131
self.moving_std = None
3232

3333
def build(self, input_shape: Shape):
34+
if self.built:
35+
return
36+
3437
self.moving_mean = self.add_weight(shape=(input_shape[-1],), initializer="zeros", trainable=False)
3538
self.moving_std = self.add_weight(shape=(input_shape[-1],), initializer="ones", trainable=False)
3639

(new test file; its path was not captured in this extract)

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,26 @@
1+
import keras
2+
import numpy as np
3+
4+
from ..utils import assert_models_equal
5+
6+
7+
def test_serialize_deserialize_continuous_approximator(tmp_path, continuous_approximator):
8+
sample_data = {
9+
"mean": np.zeros((32, 10, 2)),
10+
"std": np.ones((32, 10, 1)),
11+
"x": np.random.standard_normal((32, 10, 2)),
12+
}
13+
14+
sample_data = continuous_approximator.adapter(sample_data)
15+
16+
continuous_approximator.build_from_data(sample_data)
17+
18+
keras.saving.save_model(continuous_approximator, tmp_path / "model.keras")
19+
loaded = keras.saving.load_model(tmp_path / "model.keras")
20+
assert_models_equal(continuous_approximator, loaded)
21+
22+
# serialized = serialize(continuous_approximator)
23+
# deserialized = deserialize(serialized)
24+
# reserialized = serialize(deserialized)
25+
#
26+
# assert serialized == reserialized

tests/test_two_moons/test_two_moons.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,8 @@ def test_fit(approximator, train_dataset, validation_dataset, batch_size):
2323

2424
mock_data = train_dataset[0]
2525
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
26-
approximator.build_from_data(mock_data)
26+
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
27+
approximator.build(mock_data_shapes)
2728

2829
untrained_weights = copy.deepcopy(approximator.weights)
2930
untrained_metrics = approximator.evaluate(validation_dataset, return_dict=True)

0 commit comments

Comments
 (0)