Skip to content

Commit de8e1cb

Browse files
authored
implement compile_from_config and get_compile_config (#442)
* implement compile_from_config and get_compile_config * add optimizer build to compile_from_config
1 parent c638124 commit de8e1cb

File tree

4 files changed

+36
-1
lines changed

4 files changed

+36
-1
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ def compile(
104104

105105
return super().compile(*args, **kwargs)
106106

107+
def compile_from_config(self, config):
108+
self.compile(**deserialize(config))
109+
if hasattr(self, "optimizer") and self.built:
110+
# Create optimizer variables.
111+
self.optimizer.build(self.trainable_variables)
112+
107113
def compute_metrics(
108114
self,
109115
inference_variables: Tensor,
@@ -213,6 +219,16 @@ def get_config(self):
213219

214220
return base_config | serialize(config)
215221

222+
def get_compile_config(self):
223+
base_config = super().get_compile_config() or {}
224+
225+
config = {
226+
"inference_metrics": self.inference_network._metrics,
227+
"summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
228+
}
229+
230+
return base_config | serialize(config)
231+
216232
def estimate(
217233
self,
218234
conditions: Mapping[str, np.ndarray],

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ def compile(
118118

119119
return super().compile(*args, **kwargs)
120120

121+
def compile_from_config(self, config):
122+
self.compile(**deserialize(config))
123+
if hasattr(self, "optimizer") and self.built:
124+
# Create optimizer variables.
125+
self.optimizer.build(self.trainable_variables)
126+
121127
def compute_metrics(
122128
self,
123129
*,
@@ -262,6 +268,16 @@ def get_config(self):
262268

263269
return base_config | serialize(config)
264270

271+
def get_compile_config(self):
272+
base_config = super().get_compile_config() or {}
273+
274+
config = {
275+
"classifier_metrics": self.classifier_network._metrics,
276+
"summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
277+
}
278+
279+
return base_config | serialize(config)
280+
265281
def predict(
266282
self,
267283
*,

bayesflow/metrics/maximum_mean_discrepancy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import keras
44

5+
from bayesflow.utils.serialization import serializable
56
from .functional import maximum_mean_discrepancy
67

78

9+
@serializable
810
class MaximumMeanDiscrepancy(keras.Metric):
911
def __init__(
1012
self,

bayesflow/metrics/root_mean_squard_error.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from functools import partial
22
import keras
33

4-
4+
from bayesflow.utils.serialization import serializable
55
from .functional import root_mean_squared_error
66

77

8+
@serializable
89
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
910
def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
1011
fn = partial(root_mean_squared_error, **kwargs)

0 commit comments

Comments
 (0)