Skip to content

Commit 5da3759

Browse files
authored
Merge pull request #306 from vpratz/fix-adapter-serialization
Fix: make transforms AsSet and AsTimeSeries serializable
2 parents 9071be4 + 0a4921b commit 5da3759

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

bayesflow/adapters/transforms/as_set.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3333
return np.squeeze(data, axis=2)
3434

3535
return data
36+
37+
@classmethod
38+
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
39+
return cls()
40+
41+
def get_config(self) -> dict:
42+
return {}

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3030
return np.squeeze(data, axis=2)
3131

3232
return data
33+
34+
@classmethod
35+
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
36+
return cls()
37+
38+
def get_config(self) -> dict:
39+
return {}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,11 @@ def compute_metrics(
191191
else:
192192
# not pre-configured, resample
193193
x1 = x
194-
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
194+
if not self.built:
195+
xz_shape = keras.ops.shape(x1)
196+
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
197+
self.build(xz_shape, conditions_shape)
198+
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
195199

196200
if self.use_optimal_transport:
197201
x1, x0, conditions = optimal_transport(

0 commit comments

Comments
 (0)