Skip to content

Commit cd45b85

Browse files
committed
specify explicit build functions for approximators
1 parent c2ebd23 commit cd45b85

File tree

6 files changed

+83
-20
lines changed

6 files changed

+83
-20
lines changed

bayesflow/approximators/approximator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,15 @@
1212

1313
class Approximator(BackendApproximator):
1414
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
    """Build this approximator's layers from a dictionary of data shapes.

    Abstract hook: each concrete approximator supplies its own explicit
    build logic, so the base class deliberately refuses to build.
    """
    raise NotImplementedError
1716

1817
@classmethod
def build_adapter(cls, **kwargs) -> Adapter:
    """Construct the default data adapter for this approximator.

    Abstract hook — implemented by each respective architecture.
    """
    raise NotImplementedError
2221

2322
def build_from_data(self, adapted_data: dict[str, any]) -> None:
    """Build this approximator from a concrete batch of adapted data.

    Abstract hook: concrete approximators derive shapes from the batch and
    delegate to their explicit ``build``.
    """
    raise NotImplementedError
2624

2725
@classmethod
2826
def build_dataset(

bayesflow/approximators/continuous_approximator.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
from bayesflow.adapters import Adapter
88
from bayesflow.networks import InferenceNetwork, SummaryNetwork
99
from bayesflow.types import Tensor
10-
from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict, concatenate_valid
10+
from bayesflow.utils import (
11+
filter_kwargs,
12+
logging,
13+
split_arrays,
14+
squeeze_inner_estimates_dict,
15+
concatenate_valid,
16+
concatenate_valid_shapes,
17+
)
1118
from bayesflow.utils.serialization import serialize, deserialize, serializable
1219

1320
from .approximator import Approximator
@@ -60,6 +67,28 @@ def __init__(
6067
else:
6168
self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
6269

70+
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
    """Explicitly build all constituent networks and standardization layers
    from a dictionary of data shapes, without running a forward pass.

    Expected keys in `data_shapes`: "inference_variables", and optionally
    "inference_conditions" and "summary_variables".
    """
    conditions_shape = data_shapes.get("inference_conditions", None)

    # The summary network (if any) is built first so that its output shape
    # can be appended to the inference conditions along the last axis.
    if self.summary_network is not None:
        summary_input_shape = data_shapes["summary_variables"]
        self.summary_network.build(summary_input_shape)
        summary_output_shape = self.summary_network.compute_output_shape(summary_input_shape)
        conditions_shape = concatenate_valid_shapes([conditions_shape, summary_output_shape], axis=-1)

    self.inference_network.build(data_shapes["inference_variables"], conditions_shape)

    # Resolve the "all" shorthand into the variables actually present.
    if self.standardize == "all":
        candidates = ["inference_variables", "summary_variables", "inference_conditions"]
        self.standardize = [var for var in candidates if var in data_shapes]

    self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
    for var, layer in self.standardize_layers.items():
        layer.build(data_shapes[var])

    self.built = True
91+
6392
@classmethod
6493
def build_adapter(
6594
cls,
@@ -120,16 +149,7 @@ def compile(
120149
return super().compile(*args, **kwargs)
121150

122151
def build_from_data(self, adapted_data: dict[str, any]):
    """Build the approximator from a concrete batch of adapted data.

    Maps every tensor in the batch to its shape and delegates to ``build``.
    """
    data_shapes = keras.tree.map_structure(keras.ops.shape, adapted_data)
    self.build(data_shapes)
133153

134154
def compile_from_config(self, config):
135155
self.compile(**deserialize(config))

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bayesflow.networks import SummaryNetwork
99
from bayesflow.simulators import ModelComparisonSimulator, Simulator
1010
from bayesflow.types import Tensor
11-
from bayesflow.utils import filter_kwargs, logging, concatenate_valid
11+
from bayesflow.utils import filter_kwargs, logging, concatenate_valid, concatenate_valid_shapes
1212
from bayesflow.utils.serialization import serialize, deserialize, serializable
1313

1414
from .approximator import Approximator
@@ -66,11 +66,27 @@ def __init__(
6666
else:
6767
self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
6868

69-
def build_from_data(self, adapted_data: dict[str, any]):
69+
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
    """Explicitly build the classifier network, logits projector, optional
    summary network, and standardization layers from data shapes alone.

    Expected keys in `data_shapes`: optionally "classifier_conditions" and
    "summary_variables".
    """
    conditions_shape = data_shapes.get("classifier_conditions", None)

    # The summary network (if any) is built first so that its output shape
    # can be appended to the classifier conditions along the last axis.
    if self.summary_network is not None:
        summary_input_shape = data_shapes["summary_variables"]
        self.summary_network.build(summary_input_shape)
        summary_output_shape = self.summary_network.compute_output_shape(summary_input_shape)
        conditions_shape = concatenate_valid_shapes([conditions_shape, summary_output_shape], axis=-1)

    self.classifier_network.build(conditions_shape)
    self.logits_projector.build(self.classifier_network.compute_output_shape(conditions_shape))

    # Resolve the "all" shorthand into the variables actually present.
    if self.standardize == "all":
        self.standardize = [var for var in ["summary_variables", "classifier_conditions"] if var in data_shapes]

    self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
    for var, layer in self.standardize_layers.items():
        layer.build(data_shapes[var])

    self.built = True
87+
88+
def build_from_data(self, adapted_data: dict[str, any]):
    """Build the approximator from a concrete batch of adapted data.

    Maps every tensor in the batch to its shape and delegates to ``build``.
    """
    # BUG FIX: the original wrote `keras.tree.map_structure(keras.ops.shape(adapted_data))`,
    # i.e. it *called* keras.ops.shape on the dict (raising at runtime) and passed
    # the result as the map function. `keras.ops.shape` must be passed as the
    # function with `adapted_data` as the structure, matching
    # ContinuousApproximator.build_from_data.
    self.build(keras.tree.map_structure(keras.ops.shape, adapted_data))
7490

7591
@classmethod
7692
def build_adapter(

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777

7878
from .tensor_utils import (
7979
concatenate_valid,
80+
concatenate_valid_shapes,
8081
expand,
8182
expand_as,
8283
expand_to,

bayesflow/utils/tensor_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import keras
55
import numpy as np
66

7-
from bayesflow.types import Tensor
7+
from bayesflow.types import Tensor, Shape
88
from . import logging
99

1010
T = TypeVar("T")
@@ -20,6 +20,17 @@ def concatenate_valid(tensors: Sequence[Tensor | None], axis: int = 0) -> Tensor
2020
return keras.ops.concatenate(tensors, axis=axis)
2121

2222

23+
def concatenate_valid_shapes(tensor_shapes: Sequence[Shape | None], axis: int = 0) -> Shape | None:
    """Return the shape of concatenating tensors of the given shapes along `axis`,
    skipping entries that are None.

    Parameters
    ----------
    tensor_shapes : sequence of shapes or None
        Shapes to combine; None entries are ignored.
    axis : int
        Axis along which the concatenation is performed.

    Returns
    -------
    The combined shape as a tuple, or None if all entries are None.
    """
    tensor_shapes = [s for s in tensor_shapes if s is not None]
    if not tensor_shapes:
        return None

    # BUG FIX: copy before accumulating. The original did
    # `output_shape = tensor_shapes[0]` and then `output_shape[axis] += ...`,
    # which raises TypeError for tuple shapes (the common case, e.g. from
    # compute_output_shape) and silently mutates the caller's shape when it
    # happens to be a list.
    output_shape = list(tensor_shapes[0])
    for s in tensor_shapes[1:]:
        output_shape[axis] += s[axis]
    return tuple(output_shape)
32+
33+
2334
def expand(x: Tensor, n: int, side: str):
2435
if n < 0:
2536
raise ValueError(f"Cannot expand {n} times.")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import keras
2+
from tests.utils import check_combination_simulator_adapter
3+
4+
5+
def test_build(approximator, simulator, batch_size, adapter):
    """Building from shapes alone must create every standardization layer
    without consuming any data (running count stays at zero)."""
    check_combination_simulator_adapter(simulator, adapter)

    num_batches = 4
    raw_samples = simulator.sample((num_batches * batch_size,))

    adapted = adapter(raw_samples)
    adapted = keras.tree.map_structure(keras.ops.convert_to_tensor, adapted)
    adapted_shapes = keras.tree.map_structure(keras.ops.shape, adapted)

    approximator.build(adapted_shapes)

    for std_layer in approximator.standardize_layers.values():
        assert std_layer.built
        # shape-only build must not update the layer's running statistics
        assert std_layer.count == 0

0 commit comments

Comments
 (0)