From 9b4f14239d8e873289ead309353792a640e9c3b1 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 2 May 2025 17:27:25 +0000 Subject: [PATCH 1/6] serialization: apply new scheme for `package` (breaking) - introduces new policy for consistent naming for serialization (see #451 for a discussion): standard is the path of the module a class resides in, truncated at depth two. So for all classes in bayesflow.networks, we set package="bayesflow.networks", even if they live in the bayesflow.networks.mlp submodule. - The `serializable` decorator checks this and errors if this is not followed. The check can be disabled for certain cases (e.g., classes in the experimental module, that might eventually live somewhere else). - After this commit, previously saved models will not be loadable. As we introduced a bug regarding this anyway (#451), we will accept this and should inform users about it. - direct calls to `keras.saving.register_keras_serializable` were replaced with our custom decorator. 
--- bayesflow/adapters/adapter.py | 2 +- bayesflow/adapters/transforms/as_set.py | 2 +- .../adapters/transforms/as_time_series.py | 2 +- bayesflow/adapters/transforms/broadcast.py | 2 +- bayesflow/adapters/transforms/concatenate.py | 2 +- bayesflow/adapters/transforms/constrain.py | 2 +- .../adapters/transforms/convert_dtype.py | 2 +- bayesflow/adapters/transforms/drop.py | 2 +- .../transforms/elementwise_transform.py | 2 +- bayesflow/adapters/transforms/expand_dims.py | 2 +- .../adapters/transforms/filter_transform.py | 2 +- bayesflow/adapters/transforms/keep.py | 2 +- bayesflow/adapters/transforms/log.py | 2 +- .../adapters/transforms/map_transform.py | 2 +- .../adapters/transforms/numpy_transform.py | 2 +- bayesflow/adapters/transforms/one_hot.py | 2 +- bayesflow/adapters/transforms/rename.py | 2 +- bayesflow/adapters/transforms/scale.py | 2 +- .../serializable_custom_transform.py | 3 +- bayesflow/adapters/transforms/shift.py | 2 +- bayesflow/adapters/transforms/split.py | 2 +- bayesflow/adapters/transforms/sqrt.py | 2 +- bayesflow/adapters/transforms/standardize.py | 2 +- bayesflow/adapters/transforms/to_array.py | 2 +- bayesflow/adapters/transforms/to_dict.py | 2 +- bayesflow/adapters/transforms/transform.py | 2 +- .../approximators/continuous_approximator.py | 2 +- .../model_comparison_approximator.py | 2 +- bayesflow/approximators/point_approximator.py | 2 +- bayesflow/distributions/diagonal_normal.py | 2 +- bayesflow/distributions/diagonal_student_t.py | 2 +- bayesflow/distributions/distribution.py | 2 +- bayesflow/distributions/mixture.py | 2 +- bayesflow/experimental/cif/cif.py | 5 ++- .../experimental/cif/conditional_gaussian.py | 5 ++- .../continuous_time_consistency_model.py | 3 +- .../free_form_flow/free_form_flow.py | 3 +- bayesflow/experimental/resnet/dense_resnet.py | 3 +- bayesflow/experimental/resnet/double_conv.py | 3 +- .../experimental/resnet/double_linear.py | 3 +- bayesflow/experimental/resnet/resnet.py | 3 +- bayesflow/links/ordered.py | 4 
+- bayesflow/links/ordered_quantiles.py | 4 +- bayesflow/links/positive_definite.py | 3 +- bayesflow/metrics/maximum_mean_discrepancy.py | 2 +- bayesflow/metrics/root_mean_squard_error.py | 2 +- .../consistency_models/consistency_model.py | 2 +- bayesflow/networks/coupling_flow/actnorm.py | 2 +- .../networks/coupling_flow/coupling_flow.py | 2 +- .../coupling_flow/couplings/dual_coupling.py | 2 +- .../couplings/single_coupling.py | 2 +- .../permutations/fixed_permutation.py | 2 +- .../coupling_flow/permutations/orthogonal.py | 2 +- .../coupling_flow/permutations/random.py | 2 +- .../coupling_flow/permutations/swap.py | 2 +- .../transforms/affine_transform.py | 2 +- .../transforms/spline_transform.py | 2 +- bayesflow/networks/deep_set/deep_set.py | 2 +- .../networks/deep_set/equivariant_layer.py | 2 +- .../networks/deep_set/invariant_layer.py | 2 +- .../networks/embeddings/fourier_embedding.py | 2 +- .../embeddings/recurrent_embedding.py | 2 +- bayesflow/networks/embeddings/time2vec.py | 2 +- .../networks/flow_matching/flow_matching.py | 2 +- bayesflow/networks/mlp/mlp.py | 2 +- bayesflow/networks/point_inference_network.py | 2 +- bayesflow/networks/residual/residual.py | 2 +- .../time_series_network/skip_recurrent.py | 2 +- .../time_series_network.py | 2 +- .../transformers/fusion_transformer.py | 2 +- bayesflow/networks/transformers/isab.py | 2 +- bayesflow/networks/transformers/mab.py | 2 +- bayesflow/networks/transformers/pma.py | 2 +- bayesflow/networks/transformers/sab.py | 2 +- .../networks/transformers/set_transformer.py | 2 +- .../transformers/time_series_transformer.py | 2 +- bayesflow/scores/mean_score.py | 3 +- bayesflow/scores/median_score.py | 3 +- bayesflow/scores/multivariate_normal_score.py | 2 +- bayesflow/scores/normed_difference_score.py | 2 +- .../scores/parametric_distribution_score.py | 3 +- bayesflow/scores/quantile_score.py | 2 +- bayesflow/utils/serialization.py | 37 ++++++++++++++----- bayesflow/wrappers/mamba/mamba.py | 2 +- 
bayesflow/wrappers/mamba/mamba_block.py | 2 +- .../test_utils/test_serialize_deserialize.py | 4 +- tests/test_workflows/conftest.py | 2 +- 87 files changed, 128 insertions(+), 104 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index ab6800d8a..a17a59d81 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -29,7 +29,7 @@ from .transforms.filter_transform import Predicate -@serializable +@serializable("bayesflow.adapters") class Adapter(MutableSequence[Transform]): """ Defines an adapter to apply various transforms to data. diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index f4d5bdfc5..903536bc4 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class AsSet(ElementwiseTransform): """The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets. diff --git a/bayesflow/adapters/transforms/as_time_series.py b/bayesflow/adapters/transforms/as_time_series.py index 3f4d2a2c5..d7791352c 100644 --- a/bayesflow/adapters/transforms/as_time_series.py +++ b/bayesflow/adapters/transforms/as_time_series.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class AsTimeSeries(ElementwiseTransform): """The `.as_time_series` transform can be used to indicate that variables shall be treated as time series. 
diff --git a/bayesflow/adapters/transforms/broadcast.py b/bayesflow/adapters/transforms/broadcast.py index 8667ec0c7..646e1f72e 100644 --- a/bayesflow/adapters/transforms/broadcast.py +++ b/bayesflow/adapters/transforms/broadcast.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Broadcast(Transform): """ Broadcasts arrays or scalars to the shape of a given other array. diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 91ea9178b..ac3700616 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -7,7 +7,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Concatenate(Transform): """Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network. diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index a4ca0be25..d01211dfc 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -11,7 +11,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Constrain(ElementwiseTransform): """ Constrains neural network predictions of a data variable to specified bounds. diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py index e68815269..8cd21b4cc 100644 --- a/bayesflow/adapters/transforms/convert_dtype.py +++ b/bayesflow/adapters/transforms/convert_dtype.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ConvertDType(ElementwiseTransform): """ Default transform used to convert all floats from float64 to float32 to be in line with keras framework. 
diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index 91dcd6a28..5073027e6 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -5,7 +5,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Drop(Transform): """ Transform to drop variables from further calculation. diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 7d603d517..020301749 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.adapters") class ElementwiseTransform: """Base class on which other transforms are based""" diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index e44d133b8..0f4151d37 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ExpandDims(ElementwiseTransform): """ Expand the shape of an array. diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 7eccf370b..4dc2c8008 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -14,7 +14,7 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool: raise NotImplementedError -@serializable +@serializable("bayesflow.adapters") class FilterTransform(Transform): """ Implements a transform that applies a different transform on a subset of the data. 
diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index 56f395166..c69d01ca3 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -5,7 +5,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Keep(Transform): """ Name the data parameters that should be kept for futher calculation. diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index d5f559b4f..a42c43ef0 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Log(ElementwiseTransform): """Log transforms a variable. diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index 5da8292af..15c5c945d 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class MapTransform(Transform): """ Implements a transform that applies a set of elementwise transforms diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index 29d25dc67..a19216dd2 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class NumpyTransform(ElementwiseTransform): """ A class to apply element-wise transformations using plain NumPy functions. 
diff --git a/bayesflow/adapters/transforms/one_hot.py b/bayesflow/adapters/transforms/one_hot.py index e097a28f9..bbed8bf5d 100644 --- a/bayesflow/adapters/transforms/one_hot.py +++ b/bayesflow/adapters/transforms/one_hot.py @@ -6,7 +6,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class OneHot(ElementwiseTransform): """ Changes data to be one-hot encoded. diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 746ef5a80..bec3388b0 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -3,7 +3,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Rename(Transform): """ Transform to rename keys in data dictionary. Useful to rename variables to match those required by diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 96b2ff927..d7c1aa2a7 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Scale(ElementwiseTransform): def __init__(self, scale: np.typing.ArrayLike): self.scale = np.array(scale) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index 75d588afd..fbfc4615b 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -2,11 +2,12 @@ import numpy as np from keras.saving import ( deserialize_keras_object as deserialize, - register_keras_serializable as serializable, serialize_keras_object as serialize, get_registered_name, get_registered_object, ) + +from bayesflow.utils.serialization import serializable from .elementwise_transform import ElementwiseTransform from 
...utils import filter_kwargs import inspect diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py index b7c9659d2..5923b4e49 100644 --- a/bayesflow/adapters/transforms/shift.py +++ b/bayesflow/adapters/transforms/shift.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Shift(ElementwiseTransform): def __init__(self, shift: np.typing.ArrayLike): self.shift = np.array(shift) diff --git a/bayesflow/adapters/transforms/split.py b/bayesflow/adapters/transforms/split.py index 919db4e08..4c0ae9f65 100644 --- a/bayesflow/adapters/transforms/split.py +++ b/bayesflow/adapters/transforms/split.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Split(Transform): """This is the effective inverse of the :py:class:`~Concatenate` Transform. diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 4ef1370dc..bcfe49136 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Sqrt(ElementwiseTransform): """Square-root transform a variable. 
diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 9699819b9..a1c3c5a3d 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -7,7 +7,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Standardize(ElementwiseTransform): """ Transform that when applied standardizes data using typical z-score standardization diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index 9d5381ca0..fe1b82f2d 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -7,7 +7,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ToArray(ElementwiseTransform): """ Checks provided data for any non-arrays and converts them to numpy arrays. diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py index 6babb2a40..cfc4ec00d 100644 --- a/bayesflow/adapters/transforms/to_dict.py +++ b/bayesflow/adapters/transforms/to_dict.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class ToDict(Transform): """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches""" diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index ed3058e15..0bc6331bc 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.adapters") class Transform: """ Base class on which other transforms are based diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 834521d4b..3e43a8917 100644 --- 
a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -13,7 +13,7 @@ from .approximator import Approximator -@serializable +@serializable("bayesflow.approximators") class ContinuousApproximator(Approximator): """ Defines a workflow for performing fast posterior or likelihood inference. diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 028e8837a..86af4ee33 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -14,7 +14,7 @@ from .approximator import Approximator -@serializable +@serializable("bayesflow.approximators") class ModelComparisonApproximator(Approximator): """ Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are diff --git a/bayesflow/approximators/point_approximator.py b/bayesflow/approximators/point_approximator.py index 1e407e2a6..b3d90781c 100644 --- a/bayesflow/approximators/point_approximator.py +++ b/bayesflow/approximators/point_approximator.py @@ -11,7 +11,7 @@ from .continuous_approximator import ContinuousApproximator -@serializable +@serializable("bayesflow.approximators") class PointApproximator(ContinuousApproximator): """ A workflow for fast amortized point estimation of a conditional distribution. 
diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 98a127b1c..f8d93b945 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -12,7 +12,7 @@ from .distribution import Distribution -@serializable +@serializable("bayesflow.distributions") class DiagonalNormal(Distribution): """Implements a backend-agnostic diagonal Gaussian distribution.""" diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py index cd32a67fb..98e3fb7eb 100644 --- a/bayesflow/distributions/diagonal_student_t.py +++ b/bayesflow/distributions/diagonal_student_t.py @@ -13,7 +13,7 @@ from .distribution import Distribution -@serializable +@serializable("bayesflow.distributions") class DiagonalStudentT(Distribution): """Implements a backend-agnostic diagonal Student-t distribution.""" diff --git a/bayesflow/distributions/distribution.py b/bayesflow/distributions/distribution.py index 1d3a83962..3689f0d9f 100644 --- a/bayesflow/distributions/distribution.py +++ b/bayesflow/distributions/distribution.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.distributions") class Distribution(keras.Layer): def __init__(self, **kwargs): super().__init__(**layer_kwargs(kwargs)) diff --git a/bayesflow/distributions/mixture.py b/bayesflow/distributions/mixture.py index d7f6bd758..a7bf2ea27 100644 --- a/bayesflow/distributions/mixture.py +++ b/bayesflow/distributions/mixture.py @@ -11,7 +11,7 @@ from bayesflow.distributions import Distribution -@serializable +@serializable("bayesflow.distributions") class Mixture(Distribution): """Utility class for a backend-agnostic mixture distributions.""" diff --git a/bayesflow/experimental/cif/cif.py b/bayesflow/experimental/cif/cif.py index 8742501b3..c77e79366 100644 --- a/bayesflow/experimental/cif/cif.py +++ b/bayesflow/experimental/cif/cif.py 
@@ -1,7 +1,7 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor +from bayesflow.utils.serialization import serializable from bayesflow.networks.inference_network import InferenceNetwork from bayesflow.networks.coupling_flow import CouplingFlow @@ -9,7 +9,8 @@ from .conditional_gaussian import ConditionalGaussian -@serializable(package="bayesflow.networks") +# disable module check, use potential module after moving from experimental +@serializable(package="bayesflow.networks", disable_module_check=True) class CIF(InferenceNetwork): """Implements a continuously indexed flow (CIF) with a `CouplingFlow` bijection and `ConditionalGaussian` distributions p and q. Improves on diff --git a/bayesflow/experimental/cif/conditional_gaussian.py b/bayesflow/experimental/cif/conditional_gaussian.py index d11f3fb65..e1a5ac7f7 100644 --- a/bayesflow/experimental/cif/conditional_gaussian.py +++ b/bayesflow/experimental/cif/conditional_gaussian.py @@ -1,13 +1,14 @@ import keras -from keras.saving import register_keras_serializable import numpy as np from bayesflow.networks.mlp import MLP from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs +from bayesflow.utils.serialization import serializable -@register_keras_serializable(package="bayesflow.networks.cif") +# disable module check, use potential module after moving from experimental +@serializable(package="bayesflow.networks", disable_module_check=True) class ConditionalGaussian(keras.Layer): """Implements a conditional gaussian distribution with neural networks for the means and standard deviations respectively. Bulit in reference to [1]. 
diff --git a/bayesflow/experimental/continuous_time_consistency_model.py b/bayesflow/experimental/continuous_time_consistency_model.py index 54417cd07..b1c751454 100644 --- a/bayesflow/experimental/continuous_time_consistency_model.py +++ b/bayesflow/experimental/continuous_time_consistency_model.py @@ -22,7 +22,8 @@ from bayesflow.networks.embeddings import FourierEmbedding -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class ContinuousTimeConsistencyModel(InferenceNetwork): """Implements an sCM (simple, stable, and scalable Consistency Model) with continous-time Consistency Training (CT) as described in [1]. diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py index 61937d56f..12bb97b93 100644 --- a/bayesflow/experimental/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -19,7 +19,8 @@ from bayesflow.networks import InferenceNetwork -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class FreeFormFlow(InferenceNetwork): """Implements a dimensionality-preserving Free-form Flow. Incorporates ideas from [1-2]. diff --git a/bayesflow/experimental/resnet/dense_resnet.py b/bayesflow/experimental/resnet/dense_resnet.py index fa380969f..93ff59d3f 100644 --- a/bayesflow/experimental/resnet/dense_resnet.py +++ b/bayesflow/experimental/resnet/dense_resnet.py @@ -8,7 +8,8 @@ from .double_linear import DoubleLinear -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DenseResNet(keras.Sequential): """ Implements the fully-connected analogue of the ResNet architecture. 
diff --git a/bayesflow/experimental/resnet/double_conv.py b/bayesflow/experimental/resnet/double_conv.py index c70e37323..a2b6bbc88 100644 --- a/bayesflow/experimental/resnet/double_conv.py +++ b/bayesflow/experimental/resnet/double_conv.py @@ -5,7 +5,8 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DoubleConv(keras.Sequential): def __init__( self, diff --git a/bayesflow/experimental/resnet/double_linear.py b/bayesflow/experimental/resnet/double_linear.py index e2138c8b0..aae72fa39 100644 --- a/bayesflow/experimental/resnet/double_linear.py +++ b/bayesflow/experimental/resnet/double_linear.py @@ -5,7 +5,8 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DoubleLinear(keras.Sequential): def __init__( self, diff --git a/bayesflow/experimental/resnet/resnet.py b/bayesflow/experimental/resnet/resnet.py index 07f1f2cda..862e0ac98 100644 --- a/bayesflow/experimental/resnet/resnet.py +++ b/bayesflow/experimental/resnet/resnet.py @@ -8,7 +8,8 @@ from .double_conv import DoubleConv -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class ResNet(keras.Sequential): """ Implements the ResNet architecture. 
diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index 77545b6f8..315bcdce4 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -1,11 +1,11 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.utils import layer_kwargs from bayesflow.utils.decorators import sanitize_input_shape +from bayesflow.utils.serialization import serializable -@serializable(package="links.ordered") +@serializable(package="bayesflow.links") class Ordered(keras.Layer): """Activation function to link to a tensor which is monotonously increasing along a specified axis.""" diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py index d4f4caba2..09e565b5d 100644 --- a/bayesflow/links/ordered_quantiles.py +++ b/bayesflow/links/ordered_quantiles.py @@ -1,14 +1,14 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.utils import layer_kwargs, logging +from bayesflow.utils.serialization import serializable from collections.abc import Sequence from .ordered import Ordered -@serializable(package="links.ordered_quantiles") +@serializable(package="bayesflow.links") class OrderedQuantiles(Ordered): """Activation function to link to monotonously increasing quantile estimates.""" diff --git a/bayesflow/links/positive_definite.py b/bayesflow/links/positive_definite.py index 28c937f86..112b37305 100644 --- a/bayesflow/links/positive_definite.py +++ b/bayesflow/links/positive_definite.py @@ -1,9 +1,8 @@ import keras -from keras.saving import register_keras_serializable as serializable - from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, fill_triangular_matrix +from bayesflow.utils.serialization import serializable @serializable(package="bayesflow.links") diff --git a/bayesflow/metrics/maximum_mean_discrepancy.py b/bayesflow/metrics/maximum_mean_discrepancy.py index 37af44fd4..de4ee32f1 100644 --- 
a/bayesflow/metrics/maximum_mean_discrepancy.py +++ b/bayesflow/metrics/maximum_mean_discrepancy.py @@ -6,7 +6,7 @@ from .functional import maximum_mean_discrepancy -@serializable +@serializable("bayesflow.metrics") class MaximumMeanDiscrepancy(keras.Metric): def __init__( self, diff --git a/bayesflow/metrics/root_mean_squard_error.py b/bayesflow/metrics/root_mean_squard_error.py index 97de62e6a..8827095e9 100644 --- a/bayesflow/metrics/root_mean_squard_error.py +++ b/bayesflow/metrics/root_mean_squard_error.py @@ -5,7 +5,7 @@ from .functional import root_mean_squared_error -@serializable +@serializable("bayesflow.metrics") class RootMeanSquaredError(keras.metrics.MeanMetricWrapper): def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs): fn = partial(root_mean_squared_error, **kwargs) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index b8d4c56ed..8d36c1736 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -12,7 +12,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class ConsistencyModel(InferenceNetwork): """Implements a Consistency Model with Consistency Training (CT) a described in [1-2]. The adaptations to CT described in [2] were taken into account in our implementation for ABI [3]. diff --git a/bayesflow/networks/coupling_flow/actnorm.py b/bayesflow/networks/coupling_flow/actnorm.py index 5221caea1..81cdc425d 100644 --- a/bayesflow/networks/coupling_flow/actnorm.py +++ b/bayesflow/networks/coupling_flow/actnorm.py @@ -6,7 +6,7 @@ from .invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class ActNorm(InvertibleLayer): """Implements an Activation Normalization (ActNorm) Layer. 
Activation Normalization is learned invertible normalization, using a scale (s) and a bias (b) vector:: diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index 203962b0f..28954d7d2 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -13,7 +13,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class CouplingFlow(InferenceNetwork): """Implements a coupling flow as a sequence of dual couplings with permutations and activation normalization. Incorporates ideas from [1-5]. diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py index 67db6e269..462bc02d6 100644 --- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py @@ -9,7 +9,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class DualCoupling(InvertibleLayer): def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwargs): super().__init__(**kwargs) diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py index d2703cfa7..7bd6aaf3b 100644 --- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py @@ -8,7 +8,7 @@ from ..transforms import find_transform -@serializable +@serializable("bayesflow.networks") class SingleCoupling(InvertibleLayer): """ Implements a single coupling layer as a composition of a subnet and a transform. 
diff --git a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py index d93f27ae5..68591a172 100644 --- a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py +++ b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py @@ -6,7 +6,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class FixedPermutation(InvertibleLayer): """ Interface class for permutations with no learnable parameters. Child classes should diff --git a/bayesflow/networks/coupling_flow/permutations/orthogonal.py b/bayesflow/networks/coupling_flow/permutations/orthogonal.py index a28fe7965..54bfbf901 100644 --- a/bayesflow/networks/coupling_flow/permutations/orthogonal.py +++ b/bayesflow/networks/coupling_flow/permutations/orthogonal.py @@ -6,7 +6,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class OrthogonalPermutation(InvertibleLayer): """Implements a learnable orthogonal transformation according to [1]. Can be used as an alternative to a fixed ``Permutation`` layer. 
diff --git a/bayesflow/networks/coupling_flow/permutations/random.py b/bayesflow/networks/coupling_flow/permutations/random.py index 82d7f39ff..522d48c63 100644 --- a/bayesflow/networks/coupling_flow/permutations/random.py +++ b/bayesflow/networks/coupling_flow/permutations/random.py @@ -6,7 +6,7 @@ from .fixed_permutation import FixedPermutation -@serializable +@serializable("bayesflow.networks") class RandomPermutation(FixedPermutation): # noinspection PyMethodOverriding def build(self, xz_shape: Shape, **kwargs) -> None: diff --git a/bayesflow/networks/coupling_flow/permutations/swap.py b/bayesflow/networks/coupling_flow/permutations/swap.py index c5f707a1a..bb7f641b9 100644 --- a/bayesflow/networks/coupling_flow/permutations/swap.py +++ b/bayesflow/networks/coupling_flow/permutations/swap.py @@ -6,7 +6,7 @@ from .fixed_permutation import FixedPermutation -@serializable +@serializable("bayesflow.networks") class Swap(FixedPermutation): def build(self, xz_shape: Shape, **kwargs) -> None: shift = xz_shape[-1] // 2 diff --git a/bayesflow/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/networks/coupling_flow/transforms/affine_transform.py index 9e8c4a9e1..1d66b0bfb 100644 --- a/bayesflow/networks/coupling_flow/transforms/affine_transform.py +++ b/bayesflow/networks/coupling_flow/transforms/affine_transform.py @@ -7,7 +7,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.networks") class AffineTransform(Transform): def __init__(self, clamp: bool = True, **kwargs): super().__init__(**kwargs) diff --git a/bayesflow/networks/coupling_flow/transforms/spline_transform.py b/bayesflow/networks/coupling_flow/transforms/spline_transform.py index d5b0cf4b3..28b9c4415 100644 --- a/bayesflow/networks/coupling_flow/transforms/spline_transform.py +++ b/bayesflow/networks/coupling_flow/transforms/spline_transform.py @@ -10,7 +10,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.networks") class 
SplineTransform(Transform): def __init__( self, diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 633a1508b..6cb6fb927 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -11,7 +11,7 @@ from ..summary_network import SummaryNetwork -@serializable +@serializable("bayesflow.networks") class DeepSet(SummaryNetwork): """Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of set-based data, as generated by exchangeable models. diff --git a/bayesflow/networks/deep_set/equivariant_layer.py b/bayesflow/networks/deep_set/equivariant_layer.py index 81bd62f58..9131e7b4a 100644 --- a/bayesflow/networks/deep_set/equivariant_layer.py +++ b/bayesflow/networks/deep_set/equivariant_layer.py @@ -13,7 +13,7 @@ from .invariant_layer import InvariantLayer -@serializable +@serializable("bayesflow.networks") class EquivariantLayer(keras.Layer): """Implements an equivariant module performing an equivariant transform. diff --git a/bayesflow/networks/deep_set/invariant_layer.py b/bayesflow/networks/deep_set/invariant_layer.py index d1b6a26f9..5bc6313c0 100644 --- a/bayesflow/networks/deep_set/invariant_layer.py +++ b/bayesflow/networks/deep_set/invariant_layer.py @@ -11,7 +11,7 @@ from ..mlp import MLP -@serializable +@serializable("bayesflow.networks") class InvariantLayer(keras.Layer): """Implements an invariant module performing a permutation-invariant transform. 
diff --git a/bayesflow/networks/embeddings/fourier_embedding.py b/bayesflow/networks/embeddings/fourier_embedding.py index 65b5938d7..21924ee60 100644 --- a/bayesflow/networks/embeddings/fourier_embedding.py +++ b/bayesflow/networks/embeddings/fourier_embedding.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class FourierEmbedding(keras.Layer): """Implements a Fourier projection with normally distributed frequencies.""" diff --git a/bayesflow/networks/embeddings/recurrent_embedding.py b/bayesflow/networks/embeddings/recurrent_embedding.py index df7c00f32..3fa82868d 100644 --- a/bayesflow/networks/embeddings/recurrent_embedding.py +++ b/bayesflow/networks/embeddings/recurrent_embedding.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class RecurrentEmbedding(keras.Layer): """Implements a recurrent network for flexibly embedding time vectors.""" diff --git a/bayesflow/networks/embeddings/time2vec.py b/bayesflow/networks/embeddings/time2vec.py index 4c9c3a87f..b52ca77d8 100644 --- a/bayesflow/networks/embeddings/time2vec.py +++ b/bayesflow/networks/embeddings/time2vec.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class Time2Vec(keras.Layer): """ Implements the Time2Vec learnbale embedding from [1]. diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 3c0190467..797d4c62d 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -19,7 +19,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class FlowMatching(InferenceNetwork): """Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated from [1-3]. 
diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index 1ac11fe1a..11dcdca2b 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -9,7 +9,7 @@ from ..residual import Residual -@serializable +@serializable("bayesflow.networks") class MLP(keras.Sequential): """ Implements a simple configurable MLP with optional residual connections and dropout. diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 63094a2a8..6e823ba14 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -7,7 +7,7 @@ from bayesflow.utils.decorators import allow_batch_size -@serializable(package="networks.point_inference_network") +@serializable(package="bayesflow.networks") class PointInferenceNetwork(keras.Layer): """Implements point estimation for user specified scoring rules by a shared feed forward architecture with separate heads for each scoring rule. 
diff --git a/bayesflow/networks/residual/residual.py b/bayesflow/networks/residual/residual.py index f2ca54b51..edf32782c 100644 --- a/bayesflow/networks/residual/residual.py +++ b/bayesflow/networks/residual/residual.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +@serializable("bayesflow.networks") class Residual(keras.Sequential): def __init__(self, *layers: keras.Layer, **kwargs): if len(layers) == 1 and isinstance(layers[0], Sequence): diff --git a/bayesflow/networks/time_series_network/skip_recurrent.py b/bayesflow/networks/time_series_network/skip_recurrent.py index 9b2c06c0d..23dee5156 100644 --- a/bayesflow/networks/time_series_network/skip_recurrent.py +++ b/bayesflow/networks/time_series_network/skip_recurrent.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class SkipRecurrentNet(keras.Layer): """ Implements a Skip recurrent layer as described in [1], allowing a more flexible recurrent backbone diff --git a/bayesflow/networks/time_series_network/time_series_network.py b/bayesflow/networks/time_series_network/time_series_network.py index 354806f6c..7a96a099d 100644 --- a/bayesflow/networks/time_series_network/time_series_network.py +++ b/bayesflow/networks/time_series_network/time_series_network.py @@ -7,7 +7,7 @@ from ..summary_network import SummaryNetwork -@serializable +@serializable("bayesflow.networks") class TimeSeriesNetwork(SummaryNetwork): """ Implements a LSTNet Architecture as described in [1] diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py index 1821c25d2..f416957fb 100644 --- a/bayesflow/networks/transformers/fusion_transformer.py +++ b/bayesflow/networks/transformers/fusion_transformer.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class 
FusionTransformer(SummaryNetwork): """Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers followed by cross-attention between the representation and a learnable template summarized via a recurrent net.""" diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index ae1242469..03f15a561 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -7,7 +7,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class InducedSetAttentionBlock(keras.Layer): """Implements the ISAB block from [1] which represents learnable self-attention specifically designed to deal with large sets via a learnable set of "inducing points". diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 8f0e3f881..5bd7c9dff 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -8,7 +8,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class MultiHeadAttentionBlock(keras.Layer): """Implements the MAB block from [1] which represents learnable cross-attention. diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py index 956c85b48..bdcb2f983 100644 --- a/bayesflow/networks/transformers/pma.py +++ b/bayesflow/networks/transformers/pma.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class PoolingByMultiHeadAttention(keras.Layer): """Implements the pooling with multi-head attention (PMA) block from [1] which represents a permutation-invariant encoder for set-based inputs. 
diff --git a/bayesflow/networks/transformers/sab.py b/bayesflow/networks/transformers/sab.py index a447d92a2..276383dfd 100644 --- a/bayesflow/networks/transformers/sab.py +++ b/bayesflow/networks/transformers/sab.py @@ -7,7 +7,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class SetAttentionBlock(MultiHeadAttentionBlock): """Implements the SAB block from [1] which represents learnable self-attention. diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 6c0ab0efc..256d9e54d 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -11,7 +11,7 @@ from .pma import PoolingByMultiHeadAttention -@serializable +@serializable("bayesflow.networks") class SetTransformer(SummaryNetwork): """Implements the set transformer architecture from [1] which ultimately represents a learnable permutation-invariant function. 
Designed to naturally model interactions in diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py index 16feca444..007ae8b74 100644 --- a/bayesflow/networks/transformers/time_series_transformer.py +++ b/bayesflow/networks/transformers/time_series_transformer.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class TimeSeriesTransformer(SummaryNetwork): def __init__( self, diff --git a/bayesflow/scores/mean_score.py b/bayesflow/scores/mean_score.py index 553a7c3af..45479c623 100644 --- a/bayesflow/scores/mean_score.py +++ b/bayesflow/scores/mean_score.py @@ -1,5 +1,4 @@ -from keras.saving import register_keras_serializable as serializable - +from bayesflow.utils.serialization import serializable from .normed_difference_score import NormedDifferenceScore diff --git a/bayesflow/scores/median_score.py b/bayesflow/scores/median_score.py index 10c8809c3..8f36ee4e7 100644 --- a/bayesflow/scores/median_score.py +++ b/bayesflow/scores/median_score.py @@ -1,5 +1,4 @@ -from keras.saving import register_keras_serializable as serializable - +from bayesflow.utils.serialization import serializable from .normed_difference_score import NormedDifferenceScore diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py index 84cfd4910..ba2cec0c3 100644 --- a/bayesflow/scores/multivariate_normal_score.py +++ b/bayesflow/scores/multivariate_normal_score.py @@ -1,10 +1,10 @@ import math import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.links import PositiveDefinite +from bayesflow.utils.serialization import serializable from .parametric_distribution_score import ParametricDistributionScore diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index 
eb2795927..da6283f75 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -1,8 +1,8 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import weighted_mean +from bayesflow.utils.serialization import serializable from .scoring_rule import ScoringRule diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py index 3ead3271f..fd900cdbd 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -1,7 +1,6 @@ -from keras.saving import register_keras_serializable as serializable - from bayesflow.types import Tensor from bayesflow.utils import weighted_mean +from bayesflow.utils.serialization import serializable from .scoring_rule import ScoringRule diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index b05a35fc5..f965b4901 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -1,10 +1,10 @@ from typing import Sequence import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import logging, weighted_mean +from bayesflow.utils.serialization import serializable from bayesflow.links import OrderedQuantiles from .scoring_rule import ScoringRule diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index bb55aee41..fffc7899a 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -96,28 +96,47 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): @allow_args -def serializable(cls, package: str | None = None, name: str | None = None): - """Register class as Keras serialize. 
+def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False): + """Register class as Keras serializable. - Wrapper function around `keras.saving.register_keras_serializable` to automatically - set the `package` and `name` arguments. + Wrapper function around `keras.saving.register_keras_serializable` to automatically check consistency + of the supplied `package` argument with the module a class resides in. The `package` name should generally + be the module the class resides in, truncated at depth two. Valid examples would be "bayesflow.networks" + or "bayesflow.adapters". The check can be disabled if necessary by setting `disable_module_check` to True. + This should only be done in exceptional cases, and accompanied by a comment why it is necessary for a given + class. Parameters ---------- cls : type The class to register. - package : str, optional + package : str `package` argument forwarded to `keras.saving.register_keras_serializable`. - If None is provided, the package is automatically inferred using the __name__ - attribute of the module the class resides in. + Should generally correspond to the module of the class, truncated at depth two (e.g., "bayesflow.networks"). name : str, optional `name` argument forwarded to `keras.saving.register_keras_serializable`. If None is provided, the classe's __name__ attribute is used. + disable_module_check : bool, optional + Disable check that the provided `package` is consistent with the location of the class within the library. + + Raises + ------ + ValueError + If the supplied `package` does not correspond to the module of the class, truncated at depth two, and + `disable_module_check` is False. 
""" - if package is None: + if not disable_module_check: frame = sys._getframe(2) g = frame.f_globals - package = g.get("__name__", "bayesflow") + module_name = g.get("__name__", "bayesflow") + auto_package = ".".join(module_name.split(".")[:2]) + if package != auto_package: + raise ValueError( + "'package' should be the first two levels of the module the class resides in (e.g., bayesflow.networks)" + f'. In this case it should be \'package="{auto_package}"\' (was "{package}"). If this is not possible' + " (e.g., because a class was moved to a different module, and serializability should be preserved)," + " please set 'disable_module_check=True' and add a comment why it is necessary for this class." + ) if name is None: name = copy(cls.__name__) diff --git a/bayesflow/wrappers/mamba/mamba.py b/bayesflow/wrappers/mamba/mamba.py index d06508790..b328ede98 100644 --- a/bayesflow/wrappers/mamba/mamba.py +++ b/bayesflow/wrappers/mamba/mamba.py @@ -9,7 +9,7 @@ from .mamba_block import MambaBlock -@serializable +@serializable("bayesflow.wrappers") class Mamba(SummaryNetwork): """ Wraps a sequence of Mamba modules using the simple Mamba module from: diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py index b8ba36d2e..bd15ecc29 100644 --- a/bayesflow/wrappers/mamba/mamba_block.py +++ b/bayesflow/wrappers/mamba/mamba_block.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.wrappers") class MambaBlock(keras.Layer): """ Wraps the original Mamba module from, with added functionality for bidirectional processing: diff --git a/tests/test_utils/test_serialize_deserialize.py b/tests/test_utils/test_serialize_deserialize.py index 6c3bc3983..a9888ecc5 100644 --- a/tests/test_utils/test_serialize_deserialize.py +++ b/tests/test_utils/test_serialize_deserialize.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable 
+@serializable("test", disable_module_check=True) class Foo: @classmethod def from_config(cls, config, custom_objects=None): @@ -13,7 +13,7 @@ def get_config(self): return {} -@serializable +@serializable("test", disable_module_check=True) class Bar: @classmethod def from_config(cls, config, custom_objects=None): diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py index c98e543e9..e66b6efe4 100644 --- a/tests/test_workflows/conftest.py +++ b/tests/test_workflows/conftest.py @@ -40,7 +40,7 @@ def summary_network(request): elif request.param == "custom": from bayesflow.networks import SummaryNetwork - @serializable + @serializable("test", disable_module_check=True) class Custom(SummaryNetwork): def __init__(self, **kwargs): super().__init__(**kwargs) From f283ca0a36b5985a646a023983289e992b3fcdb4 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 2 May 2025 17:43:45 +0000 Subject: [PATCH 2/6] update serialization policy in dev docs --- docsrc/source/development/serialization.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md index 15c812686..ec8988454 100644 --- a/docsrc/source/development/serialization.md +++ b/docsrc/source/development/serialization.md @@ -17,10 +17,19 @@ As we want/need to pass them in some places, we have to resort to some custom be BayesFlows serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module. We mainly provide three convenience functions: -- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to provide automatic `package` and `name` arguments. +- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to ensure consistent naming of the `package` argument within the library. 
- The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes. - Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes. ## Usage To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples. + +### The `serializable` Decorator + +To make serialization as little confusing as possible, as well as providing stability even when moving classes around, we provide the `package` argument explicitly for each class. +The naming should respect the following naming scheme: Take the module the class resides in (for example, `bayesflow.adapters.transforms.standardize`), and truncate the path to depth two (`bayesflow.adapters`). +In cases where this convention cannot be followed, set `disable_module_check` to `True`, and describe why a different name was necessary. +Changing `package` breaks backwards-compatibility for serialization, so it should be avoided whenever possible. +If you move a class to a different module (without changing the class itself), keep the `package` and set `disable_module_check` to `True`. +This may later be adapted in a release that breaks backward compatibility anyway. From f796ab79bcc48e75f0a018d8c57e3251cd5a3f83 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 2 May 2025 17:55:37 +0000 Subject: [PATCH 3/6] README: add note regarding breaking changes until 2.1 release --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 5d2dc61f0..f2957dc9a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,12 @@ It provides users and researchers with: BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference fueled by continuous progress in generative AI and Bayesian inference. 
+> [!IMPORTANT] +> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time. +> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards. +> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update +> for less costly models. + ## Important Note for Existing Users You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library. From 2ae71ea5265857d81f49e9390a36bfc98b58f33f Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 3 May 2025 13:39:49 -0400 Subject: [PATCH 4/6] standardize use of serializable decorator --- bayesflow/adapters/transforms/serializable_custom_transform.py | 2 +- bayesflow/experimental/cif/cif.py | 2 +- bayesflow/experimental/cif/conditional_gaussian.py | 2 +- bayesflow/links/ordered.py | 2 +- bayesflow/links/ordered_quantiles.py | 2 +- bayesflow/links/positive_definite.py | 2 +- bayesflow/networks/point_inference_network.py | 2 +- bayesflow/scores/mean_score.py | 2 +- bayesflow/scores/median_score.py | 2 +- bayesflow/scores/multivariate_normal_score.py | 2 +- bayesflow/scores/normed_difference_score.py | 2 +- bayesflow/scores/parametric_distribution_score.py | 2 +- bayesflow/scores/quantile_score.py | 2 +- bayesflow/scores/scoring_rule.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index fbfc4615b..a78c69177 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -13,7 +13,7 @@ import inspect -@serializable(package="bayesflow.adapters") +@serializable("bayesflow.adapters") class SerializableCustomTransform(ElementwiseTransform): """ Transforms a parameter using a pair of registered serializable forward and inverse 
functions. diff --git a/bayesflow/experimental/cif/cif.py b/bayesflow/experimental/cif/cif.py index c77e79366..bd776f93e 100644 --- a/bayesflow/experimental/cif/cif.py +++ b/bayesflow/experimental/cif/cif.py @@ -10,7 +10,7 @@ # disable module check, use potential module after moving from experimental -@serializable(package="bayesflow.networks", disable_module_check=True) +@serializable("bayesflow.networks", disable_module_check=True) class CIF(InferenceNetwork): """Implements a continuously indexed flow (CIF) with a `CouplingFlow` bijection and `ConditionalGaussian` distributions p and q. Improves on diff --git a/bayesflow/experimental/cif/conditional_gaussian.py b/bayesflow/experimental/cif/conditional_gaussian.py index e1a5ac7f7..ebba47a2e 100644 --- a/bayesflow/experimental/cif/conditional_gaussian.py +++ b/bayesflow/experimental/cif/conditional_gaussian.py @@ -8,7 +8,7 @@ # disable module check, use potential module after moving from experimental -@serializable(package="bayesflow.networks", disable_module_check=True) +@serializable("bayesflow.networks", disable_module_check=True) class ConditionalGaussian(keras.Layer): """Implements a conditional gaussian distribution with neural networks for the means and standard deviations respectively. Bulit in reference to [1]. 
diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index 315bcdce4..25caf5350 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable -@serializable(package="bayesflow.links") +@serializable("bayesflow.links") class Ordered(keras.Layer): """Activation function to link to a tensor which is monotonously increasing along a specified axis.""" diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py index 09e565b5d..81b2c0cc7 100644 --- a/bayesflow/links/ordered_quantiles.py +++ b/bayesflow/links/ordered_quantiles.py @@ -8,7 +8,7 @@ from .ordered import Ordered -@serializable(package="bayesflow.links") +@serializable("bayesflow.links") class OrderedQuantiles(Ordered): """Activation function to link to monotonously increasing quantile estimates.""" diff --git a/bayesflow/links/positive_definite.py b/bayesflow/links/positive_definite.py index 112b37305..909ac2792 100644 --- a/bayesflow/links/positive_definite.py +++ b/bayesflow/links/positive_definite.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable -@serializable(package="bayesflow.links") +@serializable("bayesflow.links") class PositiveDefinite(keras.Layer): """Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix.""" diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 6e823ba14..402632355 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -7,7 +7,7 @@ from bayesflow.utils.decorators import allow_batch_size -@serializable(package="bayesflow.networks") +@serializable("bayesflow.networks") class PointInferenceNetwork(keras.Layer): """Implements point estimation for user specified scoring rules by a shared feed forward architecture with separate heads for each scoring rule. 
diff --git a/bayesflow/scores/mean_score.py b/bayesflow/scores/mean_score.py index 45479c623..0c7f200b2 100644 --- a/bayesflow/scores/mean_score.py +++ b/bayesflow/scores/mean_score.py @@ -2,7 +2,7 @@ from .normed_difference_score import NormedDifferenceScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MeanScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2` diff --git a/bayesflow/scores/median_score.py b/bayesflow/scores/median_score.py index 8f36ee4e7..385c47436 100644 --- a/bayesflow/scores/median_score.py +++ b/bayesflow/scores/median_score.py @@ -2,7 +2,7 @@ from .normed_difference_score import NormedDifferenceScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MedianScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |` diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py index ba2cec0c3..7c745919c 100644 --- a/bayesflow/scores/multivariate_normal_score.py +++ b/bayesflow/scores/multivariate_normal_score.py @@ -9,7 +9,7 @@ from .parametric_distribution_score import ParametricDistributionScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MultivariateNormalScore(ParametricDistributionScore): r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))` diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index da6283f75..d33bc128f 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -7,7 +7,7 @@ from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class NormedDifferenceScore(ScoringRule): r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k` diff --git a/bayesflow/scores/parametric_distribution_score.py 
b/bayesflow/scores/parametric_distribution_score.py index fd900cdbd..91df32d48 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -5,7 +5,7 @@ from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class ParametricDistributionScore(ScoringRule): r""":math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))` diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index f965b4901..7ba021340 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -10,7 +10,7 @@ from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class QuantileScore(ScoringRule): r""":math:`S(\hat \theta_i, \theta; \tau_i) = (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)` diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index 0144de458..6dee0afec 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class ScoringRule: """Base class for scoring rules. 
From a04683d7498d8803d1396d4a7df70ff8dbe88044 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 4 May 2025 07:19:53 +0000 Subject: [PATCH 5/6] [no ci] change (de)serialize to new pipeline in transform --- .../adapters/transforms/serializable_custom_transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index fbfc4615b..5e3e1e1c9 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -1,13 +1,11 @@ from collections.abc import Callable import numpy as np from keras.saving import ( - deserialize_keras_object as deserialize, - serialize_keras_object as serialize, get_registered_name, get_registered_object, ) -from bayesflow.utils.serialization import serializable +from bayesflow.utils.serialization import deserialize, serializable, serialize from .elementwise_transform import ElementwiseTransform from ...utils import filter_kwargs import inspect From 18afaf17cdc13dac849ca0f8897f4606fecbe720 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 4 May 2025 07:50:51 +0000 Subject: [PATCH 6/6] serialization check: exempt classes not in bayesflow module This should ensure that users that try to use our decorator with external classes do not encounter the error. Possible edge case: they also name their module "bayesflow". 
--- bayesflow/utils/serialization.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index fffc7899a..5be0e0e1d 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -123,14 +123,16 @@ def serializable(cls, package: str, name: str | None = None, disable_module_chec ------ ValueError If the supplied `package` does not correspond to the module of the class, truncated at depth two, and - `disable_module_check` is False. + `disable_module_check` is False. No error is thrown when a class is not part of the bayesflow module. """ if not disable_module_check: frame = sys._getframe(2) g = frame.f_globals - module_name = g.get("__name__", "bayesflow") + module_name = g.get("__name__", "") + # only apply this check if the class is inside the bayesflow module + is_bayesflow = module_name.split(".")[0] == "bayesflow" auto_package = ".".join(module_name.split(".")[:2]) - if package != auto_package: + if is_bayesflow and package != auto_package: raise ValueError( "'package' should be the first two levels of the module the class resides in (e.g., bayesflow.networks)" f'. In this case it should be \'package="{auto_package}"\' (was "{package}"). If this is not possible'