diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index dbd9eba0c..f0c1d68fa 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -104,6 +104,12 @@ def compile( return super().compile(*args, **kwargs) + def compile_from_config(self, config): + self.compile(**deserialize(config)) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. + self.optimizer.build(self.trainable_variables) + def compute_metrics( self, inference_variables: Tensor, @@ -213,6 +219,16 @@ def get_config(self): return base_config | serialize(config) + def get_compile_config(self): + base_config = super().get_compile_config() or {} + + config = { + "inference_metrics": self.inference_network._metrics, + "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, + } + + return base_config | serialize(config) + def estimate( self, conditions: Mapping[str, np.ndarray], diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 1e26f00b0..03b377537 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -118,6 +118,12 @@ def compile( return super().compile(*args, **kwargs) + def compile_from_config(self, config): + self.compile(**deserialize(config)) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. + self.optimizer.build(self.trainable_variables) + def compute_metrics( self, *, @@ -262,6 +268,16 @@ def get_config(self): return base_config | serialize(config) + def get_compile_config(self): + base_config = super().get_compile_config() or {} + + config = { + "classifier_metrics": self.classifier_network._metrics, + "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, + } + + return base_config | serialize(config) + def predict( self, *, diff --git a/bayesflow/metrics/maximum_mean_discrepancy.py b/bayesflow/metrics/maximum_mean_discrepancy.py index 64b8c35a0..37af44fd4 100644 --- a/bayesflow/metrics/maximum_mean_discrepancy.py +++ b/bayesflow/metrics/maximum_mean_discrepancy.py @@ -2,9 +2,11 @@ import keras +from bayesflow.utils.serialization import serializable from .functional import maximum_mean_discrepancy +@serializable class MaximumMeanDiscrepancy(keras.Metric): def __init__( self, diff --git a/bayesflow/metrics/root_mean_squard_error.py b/bayesflow/metrics/root_mean_squard_error.py index 13e724c14..97de62e6a 100644 --- a/bayesflow/metrics/root_mean_squard_error.py +++ b/bayesflow/metrics/root_mean_squard_error.py @@ -1,10 +1,11 @@ from functools import partial import keras - +from bayesflow.utils.serialization import serializable from .functional import root_mean_squared_error +@serializable class RootMeanSquaredError(keras.metrics.MeanMetricWrapper): def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs): fn = partial(root_mean_squared_error, **kwargs)