Skip to content

Commit 399a1b4

Browse files
committed
approximator builds: add guards against building networks twice
1 parent 5c529a2 commit 399a1b4

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
7171
summary_outputs_shape = None
7272
inference_conditions_shape = data_shapes.get("inference_conditions", None)
7373
if self.summary_network is not None:
74-
self.summary_network.build(data_shapes["summary_variables"])
74+
if not self.summary_network.built:
75+
self.summary_network.build(data_shapes["summary_variables"])
7576
summary_outputs_shape = self.summary_network.compute_output_shape(data_shapes["summary_variables"])
7677
inference_conditions_shape = concatenate_valid_shapes(
7778
[inference_conditions_shape, summary_outputs_shape], axis=-1
7879
)
79-
self.inference_network.build(data_shapes["inference_variables"], inference_conditions_shape)
80+
if not self.inference_network.built:
81+
self.inference_network.build(data_shapes["inference_variables"], inference_conditions_shape)
8082
if self.standardize == "all":
8183
self.standardize = [
8284
var

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,16 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
7070
summary_outputs_shape = None
7171
classifier_conditions_shape = data_shapes.get("classifier_conditions", None)
7272
if self.summary_network is not None:
73-
self.summary_network.build(data_shapes["summary_variables"])
73+
if not self.summary_network.built:
74+
self.summary_network.build(data_shapes["summary_variables"])
7475
summary_outputs_shape = self.summary_network.compute_output_shape(data_shapes["summary_variables"])
7576
classifier_conditions_shape = concatenate_valid_shapes(
7677
[classifier_conditions_shape, summary_outputs_shape], axis=-1
7778
)
78-
self.classifier_network.build(classifier_conditions_shape)
79-
self.logits_projector.build(self.classifier_network.compute_output_shape(classifier_conditions_shape))
79+
if not self.classifier_network.built:
80+
self.classifier_network.build(classifier_conditions_shape)
81+
if not self.logits_projector.built:
82+
self.logits_projector.build(self.classifier_network.compute_output_shape(classifier_conditions_shape))
8083
if self.standardize == "all":
8184
self.standardize = [var for var in ["summary_variables", "classifier_conditions"] if var in data_shapes]
8285

tests/test_two_moons/test_two_moons.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def test_serialize_deserialize(tmp_path, approximator, train_dataset):
6060
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
6161
approximator.build(mock_data_shapes)
6262

63+
# run a single batch through the approximator
64+
approximator.compute_metrics(**mock_data)
65+
6366
keras.saving.save_model(approximator, tmp_path / "model.keras")
6467
loaded_approximator = keras.saving.load_model(tmp_path / "model.keras")
6568

0 commit comments

Comments
 (0)