From 045df8ec715512615385c0dbd3fc8f6c29e06275 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Tue, 25 Mar 2025 19:22:59 -0400 Subject: [PATCH 1/6] Simplify warning --- bayesflow/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py index 6f098ea86..c0e829e59 100644 --- a/bayesflow/__init__.py +++ b/bayesflow/__init__.py @@ -40,9 +40,9 @@ def setup(): torch.autograd.set_grad_enabled(False) logging.warning( - "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n" - "with torch.enable_grad():\n" - "in contexts where you need gradients (e.g. custom training loops)." + "Autograd is disabled by default to avoid excessive memory usage. " + "If you need gradients (e.g., custom training loops), use\n" + "with torch.enable_grad():" ) From 2ba93d43f7baa3a68cd78cba6a55c866f1442ad3 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 10 Apr 2025 07:24:27 +0000 Subject: [PATCH 2/6] Add class for custom transforms to adapter. This commit reintroduces the features that were present in `LambdaTransform`, but only allowing registered functions. While being stricter, that allows for closer scaffolding and raising errors early on, so that users cannot provide functions that will not be (de)serializable later on. As there are a few failure modes, the focus is on providing detailed error messages to enable users to solve problems without external help. 
--- bayesflow/adapters/adapter.py | 85 +++++++- bayesflow/adapters/transforms/__init__.py | 1 + .../serializable_custom_transform.py | 184 ++++++++++++++++++ tests/test_adapters/conftest.py | 8 + tests/test_adapters/test_adapters.py | 69 +++++++ 5 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 bayesflow/adapters/transforms/serializable_custom_transform.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index d0711062c..f3eb7386e 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import MutableSequence, Sequence, Mapping +from collections.abc import Callable, MutableSequence, Sequence, Mapping import numpy as np @@ -24,6 +24,7 @@ NumpyTransform, OneHot, Rename, + SerializableCustomTransform, Sqrt, Standardize, ToArray, @@ -283,6 +284,88 @@ def apply( self.transforms.append(transform) return self + def apply_serializable( + self, + include: str | Sequence[str] = None, + *, + serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], + serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + predicate: Predicate = None, + exclude: str | Sequence[str] = None, + **kwargs, + ): + """Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter. + + Parameters + ---------- + serializable_forward_fn : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + serializable_inverse_fn : function, no lambda + Registered serializable function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. 
+ predicate : Predicate, optional + Function that indicates which variables should be transformed. + include : str or Sequence of str, optional + Names of variables to include in the transform. + exclude : str or Sequence of str, optional + Names of variables to exclude from the transform. + **kwargs : dict + Additional keyword arguments passed to the transform. + + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + Examples + -------- + + The example below shows how to use the + `keras.saving.register_keras_serializable` decorator to + register functions with Keras. Note that for this simple + example, one usually would use the simpler :py:meth:`apply` + method. + + >>> import keras + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def forward_fn(x): + >>> return x**2 + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def inverse_fn(x): + >>> return x**0.5 + >>> + >>> adapter = bf.Adapter().apply_serializable( + >>> "x", + >>> serializable_forward_fn=forward_fn, + >>> serializable_inverse_fn=inverse_fn, + >>> ) + """ + transform = FilterTransform( + transform_constructor=SerializableCustomTransform, + predicate=predicate, + include=include, + exclude=exclude, + serializable_forward_fn=serializable_forward_fn, + serializable_inverse_fn=serializable_inverse_fn, + **kwargs, + ) + self.transforms.append(transform) + return self + def as_set(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.AsSet` transform to the adapter. 
diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index b3e95a494..81e9f665f 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -15,6 +15,7 @@ from .one_hot import OneHot from .rename import Rename from .scale import Scale +from .serializable_custom_transform import SerializableCustomTransform from .shift import Shift from .sqrt import Sqrt from .standardize import Standardize diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py new file mode 100644 index 000000000..6337d1f12 --- /dev/null +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -0,0 +1,184 @@ +from collections.abc import Callable +import numpy as np +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, + get_registered_name, + get_registered_object, +) +from .elementwise_transform import ElementwiseTransform +from ...utils import filter_kwargs +import inspect + + +@serializable(package="bayesflow.adapters") +class SerializableCustomTransform(ElementwiseTransform): + """ + Transforms a parameter using a pair of registered serializable forward and inverse functions. + + Parameters + ---------- + serializable_forward_fn : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + serializable_inverse_fn : function, no lambda + Function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. 
+ + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + """ + + def __init__( + self, + *, + serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], + serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + ): + super().__init__() + + self._check_serializable(serializable_forward_fn, label="serializable_forward_fn") + self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn") + self._forward = serializable_forward_fn + self._inverse = serializable_inverse_fn + + @classmethod + def _check_serializable(cls, function, label=""): + GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function: + +``` +import keras + +@keras.saving.register_keras_serializable('custom') +def my_{label}(...): + [your code goes here...] +``` +""" + if function is None: + raise TypeError( + f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" + ) + registered_name = get_registered_name(function) + # check if function is a lambda function + if registered_name == "<lambda>": + raise ValueError( + f"The provided function for '{label}' is a lambda function, " + "which cannot be serialized. " + "Please provide a registered serializable function by using the " + "@keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + if inspect.ismethod(function): + raise ValueError( + f"The provided value for '{label}' is a method, not a function. " + "Methods cannot be serialized separately from their classes. 
" + "Please provide a registered serializable function instead by " + "moving the functionality to a function (i.e., outside of the class) and " + "using the @keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + registered_object_for_name = get_registered_object(registered_name) + if registered_object_for_name is None: + try: + source_max_lines = 5 + function_source_code = inspect.getsource(function).split("\n") + if len(function_source_code) > source_max_lines: + function_source_code = function_source_code[:source_max_lines] + [" [...]"] + + example_code = "For your provided function, this would look like this:\n\n" + example_code += "\n".join( + ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"] + + function_source_code + + ["```"] + ) + except OSError: + example_code = GENERAL_EXAMPLE_CODE + raise ValueError( + f"The provided function for '{label}' is not registered with Keras.\n" + "Please register the function using the " + "@keras.saving.register_keras_serializable decorator.\n" + f"{example_code}" + ) + if registered_object_for_name is not function: + raise ValueError( + f"The provided function for '{label}' does not match the function " + f"registered under its name '{registered_name}'. " + f"(registered function: {registered_object_for_name}, provided function: {function}). " + ) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform": + if get_registered_object(config["forward"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_forward_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The forward function that was provided as `serializable_forward_fn` " + "is not registered with Keras, making deserialization impossible. 
" + f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + if get_registered_object(config["inverse"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_inverse_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The inverse function that was provided as `serializable_inverse_fn` " + "is not registered with Keras, making deserialization impossible. " + f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + forward = deserialize(config["forward"], custom_objects) + inverse = deserialize(config["inverse"], custom_objects) + return cls( + serializable_forward_fn=forward, + serializable_inverse_fn=inverse, + ) + + def get_config(self) -> dict: + forward_source_code = inverse_source_code = None + try: + forward_source_code = inspect.getsource(self._forward) + inverse_source_code = inspect.getsource(self._inverse) + except OSError: + pass + return { + "forward": serialize(self._forward), + "inverse": serialize(self._inverse), + "_forward_source_code": forward_source_code, + "_inverse_source_code": inverse_source_code, + } + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + # filter kwargs so that other transform args like batch_size, strict, ... 
are not passed through + kwargs = filter_kwargs(kwargs, self._forward) + return self._forward(data, **kwargs) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + kwargs = filter_kwargs(kwargs, self._inverse) + return self._inverse(data, **kwargs) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index abd5797bd..b8365e94a 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -5,6 +5,11 @@ @pytest.fixture() def adapter(): from bayesflow.adapters import Adapter + import keras + + @keras.saving.register_keras_serializable("custom") + def serializable_fn(x): + return x d = ( Adapter() @@ -20,6 +25,9 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") + .apply_serializable( + include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn + ) .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 3dea0baf4..69edb6e34 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -3,6 +3,7 @@ serialize_keras_object as serialize, ) import numpy as np +import pytest def test_cycle_consistency(adapter, random_data): @@ -110,3 +111,71 @@ def test_simple_transforms(random_data): assert np.allclose(inverse["t1"], random_data["t1"]) assert np.allclose(inverse["p1"], random_data["p1"]) + + +def test_custom_transform(): + # test that transform raises errors in all relevant cases + import keras + from bayesflow.adapters.transforms import SerializableCustomTransform + from copy import deepcopy + + class A: + @classmethod + def fn(cls, x): + return x + + def not_registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def 
registered_but_changed(x): + return x + + def registered_but_changed(x): # noqa: F811 + return 2 * x + + # method instead of function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn) + + # lambda function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x) + + # unregistered function provided + with pytest.raises(ValueError): + SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn) + SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn) + + # function does not match registered function + with pytest.raises(ValueError): + SerializableCustomTransform( + serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn + ) + SerializableCustomTransform( + serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed + ) + + transform = SerializableCustomTransform( + serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn + ) + serialized_transform = keras.saving.serialize_keras_object(transform) + keras.saving.deserialize_keras_object(serialized_transform) + + # modify name of the forward function so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) + + # modify name of the inverse transform so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + 
corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) From 8fdfde9bc435995939e422c0b067e091c65dfec4 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 11 Apr 2025 16:17:59 +0000 Subject: [PATCH 3/6] custom transform: less verbose naming, fix tests --- bayesflow/adapters/adapter.py | 16 +++---- .../serializable_custom_transform.py | 43 +++++++++---------- tests/test_adapters/conftest.py | 4 +- tests/test_adapters/test_adapters.py | 30 ++++++------- 4 files changed, 44 insertions(+), 49 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index f3eb7386e..6476ffcb4 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -288,8 +288,8 @@ def apply_serializable( self, include: str | Sequence[str] = None, *, - serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], - serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], predicate: Predicate = None, exclude: str | Sequence[str] = None, **kwargs, @@ -298,12 +298,12 @@ def apply_serializable( Parameters ---------- - serializable_forward_fn : function, no lambda + forward : function, no lambda Registered serializable function to transform the data in the forward pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda functions can be used here. - serializable_inverse_fn : function, no lambda + inverse : function, no lambda Registered serializable function to transform the data in the inverse pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). 
Therefore, only proper functions and no lambda @@ -350,8 +350,8 @@ def apply_serializable( >>> >>> adapter = bf.Adapter().apply_serializable( >>> "x", - >>> serializable_forward_fn=forward_fn, - >>> serializable_inverse_fn=inverse_fn, + >>> forward=forward_fn, + >>> inverse=inverse_fn, >>> ) """ transform = FilterTransform( @@ -359,8 +359,8 @@ def apply_serializable( predicate=predicate, include=include, exclude=exclude, - serializable_forward_fn=serializable_forward_fn, - serializable_inverse_fn=serializable_inverse_fn, + forward=forward, + inverse=inverse, **kwargs, ) self.transforms.append(transform) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index 6337d1f12..75d588afd 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -19,12 +19,12 @@ class SerializableCustomTransform(ElementwiseTransform): Parameters ---------- - serializable_forward_fn : function, no lambda + forward : function, no lambda Registered serializable function to transform the data in the forward pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). Therefore, only proper functions and no lambda functions can be used here. - serializable_inverse_fn : function, no lambda + inverse : function, no lambda Function to transform the data in the inverse pass. For the adapter to be serializable, this function has to be serializable as well (see Notes). 
Therefore, only proper functions and no lambda @@ -47,28 +47,27 @@ class SerializableCustomTransform(ElementwiseTransform): def __init__( self, *, - serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], - serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], ): super().__init__() - self._check_serializable(serializable_forward_fn, label="serializable_forward_fn") - self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn") - self._forward = serializable_forward_fn - self._inverse = serializable_inverse_fn + self._check_serializable(forward, label="forward") + self._check_serializable(inverse, label="inverse") + self._forward = forward + self._inverse = inverse @classmethod def _check_serializable(cls, function, label=""): - GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function: - -``` -import keras - -@keras.saving.register_keras_serializable('custom') -def my_{label}(...): - [your code goes here...] -``` -""" + GENERAL_EXAMPLE_CODE = ( + "The example code below shows the structure of a correctly decorated function:\n\n" + "```\n" + "import keras\n\n" + "@keras.saving.register_keras_serializable('custom')\n" + f"def my_{label}(...):\n" + " [your code goes here...]\n" + "```\n" + ) if function is None: raise TypeError( f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" @@ -132,7 +131,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr raise TypeError( "\n\nPLEASE READ HERE:\n" "-----------------\n" - "The forward function that was provided as `serializable_forward_fn` " + "The forward function that was provided as `forward` " "is not registered with Keras, making deserialization impossible. 
" f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " "function before loading your model." @@ -147,7 +146,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr raise TypeError( "\n\nPLEASE READ HERE:\n" "-----------------\n" - "The inverse function that was provided as `serializable_inverse_fn` " + "The inverse function that was provided as `inverse` " "is not registered with Keras, making deserialization impossible. " f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " "function before loading your model." @@ -156,8 +155,8 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr forward = deserialize(config["forward"], custom_objects) inverse = deserialize(config["inverse"], custom_objects) return cls( - serializable_forward_fn=forward, - serializable_inverse_fn=inverse, + forward=forward, + inverse=inverse, ) def get_config(self) -> dict: diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index b8365e94a..d379d7cc4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -25,9 +25,7 @@ def serializable_fn(x): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") - .apply_serializable( - include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn - ) + .apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn) .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 69edb6e34..0d17c419f 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -140,31 +140,29 @@ def registered_but_changed(x): # noqa: F811 # method instead of function provided with 
pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn) + SerializableCustomTransform(forward=A.fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=A.fn) # lambda function provided with pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x) + SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x) # unregistered function provided with pytest.raises(ValueError): - SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn) - SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn) + SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn) # function does not match registered function with pytest.raises(ValueError): - SerializableCustomTransform( - serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn - ) - SerializableCustomTransform( - serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed - ) - - transform = SerializableCustomTransform( - serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn - ) + SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed) + + transform 
= SerializableCustomTransform(forward=registered_fn, inverse=registered_fn) serialized_transform = keras.saving.serialize_keras_object(transform) keras.saving.deserialize_keras_object(serialized_transform) From 1c6b90dfe2653905d26f1a2a550767fc28dc0f02 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 11 Apr 2025 17:08:27 +0000 Subject: [PATCH 4/6] Enable integration with fixed, non-uniform schedule. This enables specialized, manually specified schedules, as for example required in diffusion models for inference. Closes #402. --- bayesflow/utils/integrate.py | 64 +++++++++++++++++++++++++----- tests/test_utils/__init__.py | 0 tests/test_utils/conftest.py | 0 tests/test_utils/test_integrate.py | 11 +++++ 4 files changed, 64 insertions(+), 11 deletions(-) create mode 100644 tests/test_utils/__init__.py create mode 100644 tests/test_utils/conftest.py create mode 100644 tests/test_utils/test_integrate.py diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a932424f5..5e3b407ec 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial import keras +import numpy as np from typing import Literal from bayesflow.types import Tensor @@ -233,21 +234,62 @@ def body(_state, _time, _step_size, _step): return state +def integrate_scheduled( + fn: Callable, + state: dict[str, ArrayLike], + steps: Tensor | np.ndarray, + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - 
steps[_loop_var] + + _loop_state, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + return state + + def integrate( fn: Callable, state: dict[str, ArrayLike], - start_time: ArrayLike, - stop_time: ArrayLike, + start_time: ArrayLike | None = None, + stop_time: ArrayLike | None = None, min_steps: int = 10, max_steps: int = 10_000, - steps: int | Literal["adaptive"] = 100, + steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, method: str = "euler", **kwargs, ) -> dict[str, ArrayLike]: - match steps: - case "adaptive" | "dynamic": - return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) - case int(): - return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) - case _: - raise RuntimeError("Type or value of `steps` not understood.") + if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." + ) + return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) + elif isinstance(steps, int): + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." 
+ ) + return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) + elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps): + return integrate_scheduled(fn, state, steps, method, **kwargs) + else: + raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/conftest.py b/tests/test_utils/conftest.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py new file mode 100644 index 000000000..8147d0589 --- /dev/null +++ b/tests/test_utils/test_integrate.py @@ -0,0 +1,11 @@ +def test_scheduled_integration(): + import keras + from bayesflow.utils import integrate + + def fn(t, x): + return {"x": t**2} + + steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) + approximate_result = 0.0 + 0.5**2 * 0.5 + result = integrate(fn, {"x": 0.0}, steps=steps)["x"] + assert result == approximate_result From 18df01c85dc826c61b80e6ae94256478cb0e3adb Mon Sep 17 00:00:00 2001 From: Stefan Radev Date: Fri, 11 Apr 2025 19:10:02 -0400 Subject: [PATCH 5/6] Uncomment check for torch backend. 
--- bayesflow/wrappers/mamba/mamba_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py index 17617e33d..f629807a0 100644 --- a/bayesflow/wrappers/mamba/mamba_block.py +++ b/bayesflow/wrappers/mamba/mamba_block.py @@ -55,8 +55,8 @@ def __init__( super().__init__(**keras_kwargs(kwargs)) - # if keras.backend.backend() != "torch": - # raise RuntimeError("Mamba is only available using torch backend.") + if keras.backend.backend() != "torch": + raise RuntimeError("Mamba is only available using torch backend.") try: from mamba_ssm import Mamba From 2bf0b5336a7013cc1aff9ec4a5fe80ad9768b7b8 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Sat, 12 Apr 2025 17:23:23 -0400 Subject: [PATCH 6/6] Remove old usage of aggregate and use weighted_mean everywhere --- .../free_form_flow/free_form_flow.py | 3 ++- .../consistency_models/consistency_model.py | 10 +++++-- .../networks/coupling_flow/coupling_flow.py | 4 +-- .../networks/flow_matching/flow_matching.py | 4 +-- bayesflow/scores/normed_difference_score.py | 3 ++- .../scores/parametric_distribution_score.py | 3 ++- bayesflow/scores/quantile_score.py | 6 ++--- bayesflow/scores/scoring_rule.py | 27 ------------------- bayesflow/utils/__init__.py | 2 +- bayesflow/utils/tensor_utils.py | 2 +- 10 files changed, 23 insertions(+), 41 deletions(-) diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py index a4ad3c8be..b2dd911b7 100644 --- a/bayesflow/experimental/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -12,6 +12,7 @@ vjp, serialize_value_or_type, deserialize_value_or_type, + weighted_mean, ) from bayesflow.networks import InferenceNetwork @@ -240,6 +241,6 @@ def decode(z): reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1) losses = maximum_likelihood_loss + self.beta * reconstruction_loss - 
loss = self.aggregate(losses, sample_weight) + loss = weighted_mean(losses, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 18fefdc91..c983a5cde 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -7,7 +7,13 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum +from bayesflow.utils import ( + find_network, + keras_kwargs, + serialize_value_or_type, + deserialize_value_or_type, + weighted_mean, +) from ..inference_network import InferenceNetwork @@ -331,6 +337,6 @@ def compute_metrics( # Pseudo-huber loss, see [2], Section 3.3 loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber) - loss = weighted_sum(loss, sample_weight) + loss = weighted_mean(loss, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index d50432ea6..403d92511 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -7,7 +7,7 @@ keras_kwargs, serialize_value_or_type, deserialize_value_or_type, - weighted_sum, + weighted_mean, ) from .actnorm import ActNorm @@ -167,6 +167,6 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) z, log_density = self(x, conditions=conditions, inverse=False, density=True) - loss = weighted_sum(-log_density, sample_weight) + loss = weighted_mean(-log_density, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index f8a39629b..85c338d82 100644 
--- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -13,7 +13,7 @@ optimal_transport, serialize_value_or_type, deserialize_value_or_type, - weighted_sum, + weighted_mean, ) from ..inference_network import InferenceNetwork @@ -260,6 +260,6 @@ def compute_metrics( predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training") loss = self.loss_fn(target_velocity, predicted_velocity) - loss = weighted_sum(loss, sample_weight) + loss = weighted_mean(loss, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index 9659d3726..eb2795927 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -2,6 +2,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor +from bayesflow.utils import weighted_mean from .scoring_rule import ScoringRule @@ -55,7 +56,7 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = """ estimates = estimates["value"] scores = keras.ops.absolute(estimates - targets) ** self.k - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score def get_config(self): diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py index bdd62d692..3ead3271f 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -1,6 +1,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor +from bayesflow.utils import weighted_mean from .scoring_rule import ScoringRule @@ -29,5 +30,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = :math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))` """ scores = 
-self.log_prob(x=targets, **estimates) - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index 2e3ec54ef..b05a35fc5 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -4,7 +4,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor -from bayesflow.utils import logging +from bayesflow.utils import logging, weighted_mean from bayesflow.links import OrderedQuantiles from .scoring_rule import ScoringRule @@ -39,7 +39,7 @@ def get_config(self): base_config = super().get_config() return base_config | self.config - def get_head_shapes_from_target_shape(self, target_shape: Shape): + def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, tuple]: # keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion target_shape = tuple(target_shape) return dict(value=(len(self.q),) + target_shape[1:]) @@ -49,5 +49,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = pointwise_differance = estimates - targets[:, None, :] scores = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self._q[None, :, None]) scores = keras.ops.mean(scores, axis=1) - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index 87ad216e9..a1a3f5717 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -204,30 +204,3 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) """ raise NotImplementedError - - def aggregate(self, scores: Tensor, weights: Tensor = None) -> Tensor: - """ - Computes the mean of **scores**, optionally applying **weights**. 
- - This function computes the mean value of the given scores. When weights are provided, - it first multiplies the scores by the weights and then computes the mean of the result. - If no weights are provided, it computes the mean of the scores. - - Parameters - ---------- - scores : Tensor - A tensor containing the scores to be aggregated. - weights : Tensor, optional (default - None) - A tensor of weights corresponding to each score. Must be the same shape as `scores`. - If not provided, the function returns the mean of `scores`. - - Returns - ------- - Tensor - The aggregated score computed as a weighted mean if **weights** is provided, - or as the simple mean of **scores** otherwise. - """ - - if weights is not None: - return keras.ops.mean(scores * weights) - return keras.ops.mean(scores) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 5bdaea274..1e1092e0c 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -71,7 +71,7 @@ tree_concatenate, tree_stack, fill_triangular_matrix, - weighted_sum, + weighted_mean, ) from .classification import calibration_curve, confusion_matrix from .validators import check_lengths_same diff --git a/bayesflow/utils/tensor_utils.py b/bayesflow/utils/tensor_utils.py index cd6f6d4ca..4d89249b7 100644 --- a/bayesflow/utils/tensor_utils.py +++ b/bayesflow/utils/tensor_utils.py @@ -140,7 +140,7 @@ def pad(x: Tensor, value: float | Tensor, n: int, axis: int, side: str = "both") raise TypeError(f"Invalid side type {type(side)!r}. Must be str.") -def weighted_sum(elements: Tensor, weights: Tensor = None) -> Tensor: +def weighted_mean(elements: Tensor, weights: Tensor = None) -> Tensor: """ Compute the (optionally) weighted mean of the input tensor.