Skip to content

Commit bb068d2

Browse files
committed
implement compile_from_config and get_compile_config
1 parent 42fa035 commit bb068d2

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def compile(
104104

105105
return super().compile(*args, **kwargs)
106106

107+
def compile_from_config(self, config):
108+
return self.compile(**deserialize(config))
109+
107110
def compute_metrics(
108111
self,
109112
inference_variables: Tensor,
@@ -213,6 +216,16 @@ def get_config(self):
213216

214217
return base_config | serialize(config)
215218

219+
def get_compile_config(self):
220+
base_config = super().get_compile_config() or {}
221+
222+
config = {
223+
"inference_metrics": self.inference_network._metrics,
224+
"summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
225+
}
226+
227+
return base_config | serialize(config)
228+
216229
def estimate(
217230
self,
218231
conditions: Mapping[str, np.ndarray],

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def compile(
118118

119119
return super().compile(*args, **kwargs)
120120

121+
def compile_from_config(self, config):
122+
return self.compile(**deserialize(config))
123+
121124
def compute_metrics(
122125
self,
123126
*,
@@ -262,6 +265,16 @@ def get_config(self):
262265

263266
return base_config | serialize(config)
264267

268+
def get_compile_config(self):
269+
base_config = super().get_compile_config() or {}
270+
271+
config = {
272+
"classifier_metrics": self.classifier_network._metrics,
273+
"summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
274+
}
275+
276+
return base_config | serialize(config)
277+
265278
def predict(
266279
self,
267280
*,

bayesflow/metrics/maximum_mean_discrepancy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import keras
44

5+
from bayesflow.utils.serialization import serializable
56
from .functional import maximum_mean_discrepancy
67

78

9+
@serializable
810
class MaximumMeanDiscrepancy(keras.Metric):
911
def __init__(
1012
self,

bayesflow/metrics/root_mean_squard_error.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from functools import partial
22
import keras
33

4-
4+
from bayesflow.utils.serialization import serializable
55
from .functional import root_mean_squared_error
66

77

8+
@serializable
89
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
910
def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
1011
fn = partial(root_mean_squared_error, **kwargs)

0 commit comments

Comments
 (0)