 from bayesflow.adapters import Adapter
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict, concatenate_valid
+from bayesflow.utils import (
+    filter_kwargs,
+    logging,
+    split_arrays,
+    squeeze_inner_estimates_dict,
+    concatenate_valid,
+    concatenate_valid_shapes,
+)
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 
 from .approximator import Approximator
@@ -60,6 +67,28 @@ def __init__(
         else:
             self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
 
+    def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
+        summary_outputs_shape = None
+        inference_conditions_shape = data_shapes.get("inference_conditions", None)
+        if self.summary_network is not None:
+            self.summary_network.build(data_shapes["summary_variables"])
+            summary_outputs_shape = self.summary_network.compute_output_shape(data_shapes["summary_variables"])
+            inference_conditions_shape = concatenate_valid_shapes(
+                [inference_conditions_shape, summary_outputs_shape], axis=-1
+            )
+        self.inference_network.build(data_shapes["inference_variables"], inference_conditions_shape)
+        if self.standardize == "all":
+            self.standardize = [
+                var
+                for var in ["inference_variables", "summary_variables", "inference_conditions"]
+                if var in data_shapes
+            ]
+
+        self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
+        for var, layer in self.standardize_layers.items():
+            layer.build(data_shapes[var])
+        self.built = True
+
     @classmethod
     def build_adapter(
         cls,
@@ -120,16 +149,7 @@ def compile(
         return super().compile(*args, **kwargs)
 
     def build_from_data(self, adapted_data: dict[str, any]):
-        if self.standardize == "all":
-            self.standardize = [
-                var
-                for var in ["inference_variables", "summary_variables", "inference_conditions"]
-                if var in adapted_data
-            ]
-
-        self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
-
-        super().build_from_data(adapted_data)
+        self.build(keras.tree.map_structure(keras.ops.shape, adapted_data))
 
     def compile_from_config(self, config):
         self.compile(**deserialize(config))
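
For context, a minimal sketch of what the new `build_from_data` delegation amounts to, using made-up variable names and shapes (the actual `approximator.build(...)` call is left commented out, since constructing an approximator requires networks and an adapter):

```python
import keras

# Illustrative adapted batch; keys and shapes are hypothetical.
adapted_data = {
    "inference_variables": keras.ops.zeros((64, 4)),    # e.g. 4 parameters
    "summary_variables": keras.ops.zeros((64, 10, 2)),  # e.g. 10 obs x 2 dims
}

# build_from_data now forwards only the shapes; the tensors themselves
# are no longer needed to build the networks and standardization layers.
data_shapes = keras.tree.map_structure(keras.ops.shape, adapted_data)
# -> {"inference_variables": (64, 4), "summary_variables": (64, 10, 2)}

# approximator.build(data_shapes)  # would build the summary/inference
#                                  # networks and one Standardization
#                                  # layer per entry in `standardize`
```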