def apply_serializable(
    self,
    include: str | Sequence[str] = None,
    *,
    forward: Callable[[np.ndarray, ...], np.ndarray],
    inverse: Callable[[np.ndarray, ...], np.ndarray],
    predicate: Predicate = None,
    exclude: str | Sequence[str] = None,
    **kwargs,
):
    """Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.

    The transform is wrapped in a ``FilterTransform`` so that it is only
    applied to the variables selected via ``include``, ``exclude`` and
    ``predicate``.

    Parameters
    ----------
    include : str or Sequence of str, optional
        Names of variables to include in the transform.
    forward : function, no lambda
        Registered serializable function to transform the data in the forward pass.
        For the adapter to be serializable, this function has to be serializable
        as well (see Notes). Therefore, only proper functions and no lambda
        functions can be used here.
    inverse : function, no lambda
        Registered serializable function to transform the data in the inverse pass.
        For the adapter to be serializable, this function has to be serializable
        as well (see Notes). Therefore, only proper functions and no lambda
        functions can be used here.
    predicate : Predicate, optional
        Function that indicates which variables should be transformed.
    exclude : str or Sequence of str, optional
        Names of variables to exclude from the transform.
    **kwargs : dict
        Additional keyword arguments passed to the transform.

    Returns
    -------
    Adapter
        The adapter itself, so that calls can be chained.

    Raises
    ------
    TypeError
        When ``forward`` or ``inverse`` is ``None``.
    ValueError
        When the provided functions are not registered serializable functions.

    Notes
    -----
    Important: The forward and inverse functions have to be registered with Keras.
    To do so, use the `@keras.saving.register_keras_serializable` decorator.
    They must also be registered (and identical) when loading the adapter
    at a later point in time.

    Examples
    --------

    The example below shows how to use the
    `keras.saving.register_keras_serializable` decorator to
    register functions with Keras. Note that for this simple
    example, one usually would use the simpler :py:meth:`apply`
    method.

    >>> import bayesflow as bf
    >>> import keras
    >>>
    >>> @keras.saving.register_keras_serializable("custom")
    ... def forward_fn(x):
    ...     return x**2
    >>>
    >>> @keras.saving.register_keras_serializable("custom")
    ... def inverse_fn(x):
    ...     return x**0.5
    >>>
    >>> adapter = bf.Adapter().apply_serializable(
    ...     "x",
    ...     forward=forward_fn,
    ...     inverse=inverse_fn,
    ... )
    """
    # The functions are validated by SerializableCustomTransform itself when
    # FilterTransform constructs it, so no checks are duplicated here.
    transform = FilterTransform(
        transform_constructor=SerializableCustomTransform,
        predicate=predicate,
        include=include,
        exclude=exclude,
        forward=forward,
        inverse=inverse,
        **kwargs,
    )
    self.transforms.append(transform)
    return self
@serializable(package="bayesflow.adapters")
class SerializableCustomTransform(ElementwiseTransform):
    """
    Transforms a parameter using a pair of registered serializable forward and inverse functions.

    Parameters
    ----------
    forward : function, no lambda
        Registered serializable function to transform the data in the forward pass.
        For the adapter to be serializable, this function has to be serializable
        as well (see Notes). Therefore, only proper functions and no lambda
        functions can be used here.
    inverse : function, no lambda
        Registered serializable function to transform the data in the inverse pass.
        For the adapter to be serializable, this function has to be serializable
        as well (see Notes). Therefore, only proper functions and no lambda
        functions can be used here.

    Raises
    ------
    TypeError
        When one of the provided values is not callable (e.g., ``None``).
    ValueError
        When the provided functions are not registered serializable functions.

    Notes
    -----
    Important: The forward and inverse functions have to be registered with Keras.
    To do so, use the `@keras.saving.register_keras_serializable` decorator.
    They must also be registered (and identical) when loading the adapter
    at a later point in time.

    """

    def __init__(
        self,
        *,
        forward: Callable[[np.ndarray, ...], np.ndarray],
        inverse: Callable[[np.ndarray, ...], np.ndarray],
    ):
        super().__init__()

        # Validate both functions before storing either, so that an invalid
        # pair never produces a partially initialized transform.
        self._check_serializable(forward, label="forward")
        self._check_serializable(inverse, label="inverse")
        self._forward = forward
        self._inverse = inverse

    @staticmethod
    def _get_source_or_none(function):
        """Return the source code of ``function``, or ``None`` if it cannot be retrieved."""
        try:
            return inspect.getsource(function)
        except (OSError, TypeError):
            # OSError: source file is not available (e.g., interactive session).
            # TypeError: the object is a builtin or otherwise has no Python source.
            return None

    @classmethod
    def _registration_example(cls, function, fallback: str, max_lines: int = 5) -> str:
        """Build a decorator example from the (truncated) source of ``function``.

        Returns ``fallback`` when the source code is not available.
        """
        source = cls._get_source_or_none(function)
        if source is None:
            return fallback

        source_lines = source.split("\n")
        if len(source_lines) > max_lines:
            source_lines = source_lines[:max_lines] + ["    [...]"]

        header = ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
        return "For your provided function, this would look like this:\n\n" + "\n".join(
            header + source_lines + ["```"]
        )

    @classmethod
    def _check_serializable(cls, function, label=""):
        """Ensure that ``function`` is a function registered with Keras.

        Parameters
        ----------
        function : callable
            The candidate forward or inverse function.
        label : str, optional
            Name of the argument (``"forward"`` or ``"inverse"``), used in error messages.

        Raises
        ------
        TypeError
            If ``function`` is not callable (e.g., ``None``).
        ValueError
            If ``function`` is a lambda, a bound method, not registered with Keras,
            or not the same object as the one registered under its name.
        """
        general_example_code = (
            "The example code below shows the structure of a correctly decorated function:\n\n"
            "```\n"
            "import keras\n\n"
            "@keras.saving.register_keras_serializable('custom')\n"
            f"def my_{label}(...):\n"
            "    [your code goes here...]\n"
            "```\n"
        )
        if not callable(function):
            raise TypeError(
                f"'{label}' must be a registered serializable function, "
                f"was '{type(function).__name__}'.\n{general_example_code}"
            )

        registered_name = get_registered_name(function)

        # Unregistered objects are reported under their ``__name__``, which is
        # the literal string "<lambda>" for lambda functions.
        if registered_name == "<lambda>":
            raise ValueError(
                f"The provided function for '{label}' is a lambda function, "
                "which cannot be serialized. "
                "Please provide a registered serializable function by using the "
                "@keras.saving.register_keras_serializable decorator."
                f"\n{general_example_code}"
            )

        if inspect.ismethod(function):
            raise ValueError(
                f"The provided value for '{label}' is a method, not a function. "
                "Methods cannot be serialized separately from their classes. "
                "Please provide a registered serializable function instead by "
                "moving the functionality to a function (i.e., outside of the class) and "
                "using the @keras.saving.register_keras_serializable decorator."
                f"\n{general_example_code}"
            )

        registered_object_for_name = get_registered_object(registered_name)
        if registered_object_for_name is None:
            example_code = cls._registration_example(function, fallback=general_example_code)
            raise ValueError(
                f"The provided function for '{label}' is not registered with Keras.\n"
                "Please register the function using the "
                "@keras.saving.register_keras_serializable decorator.\n"
                f"{example_code}"
            )

        # A different function may have been registered under the same name
        # (e.g., after redefining it); deserialization would then silently
        # restore the wrong function.
        if registered_object_for_name is not function:
            raise ValueError(
                f"The provided function for '{label}' does not match the function "
                f"registered under its name '{registered_name}'. "
                f"(registered function: {registered_object_for_name}, provided function: {function}). "
            )

    @staticmethod
    def _require_registered(config: dict, key: str, custom_objects=None) -> None:
        """Raise a descriptive ``TypeError`` if the function stored under ``key`` is not registered.

        ``key`` is either ``"forward"`` or ``"inverse"``.
        """
        registered_name = config[key]["config"]
        if get_registered_object(registered_name, custom_objects) is not None:
            return

        # The stored source code is purely informational and may be absent.
        source_code = config.get(f"_{key}_source_code")
        provided_function_msg = ""
        if source_code:
            provided_function_msg = f"\nThe originally provided function was:\n\n```\n{source_code}\n```"

        raise TypeError(
            "\n\nPLEASE READ HERE:\n"
            "-----------------\n"
            f"The {key} function that was provided as `{key}` "
            "is not registered with Keras, making deserialization impossible. "
            f"Please ensure that it is registered as '{registered_name}' and identical to the original "
            "function before loading your model."
            f"{provided_function_msg}"
        )

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
        """Recreate the transform from its config.

        Raises
        ------
        TypeError
            If the forward or inverse function is not registered with Keras
            (or provided via ``custom_objects``) at load time.
        """
        for key in ("forward", "inverse"):
            cls._require_registered(config, key, custom_objects)

        return cls(
            forward=deserialize(config["forward"], custom_objects),
            inverse=deserialize(config["inverse"], custom_objects),
        )

    def get_config(self) -> dict:
        """Return the serializable config of the transform.

        In addition to the serialized functions, the source code of each
        function is stored (``None`` if unavailable). It is only used to
        produce a more helpful error message when deserialization fails.
        """
        return {
            "forward": serialize(self._forward),
            "inverse": serialize(self._inverse),
            # Looked up independently so one failure does not discard the other.
            "_forward_source_code": self._get_source_or_none(self._forward),
            "_inverse_source_code": self._get_source_or_none(self._inverse),
        }

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the registered forward function to ``data``."""
        # filter kwargs so that other transform args like batch_size, strict, ... are not passed through
        kwargs = filter_kwargs(kwargs, self._forward)
        return self._forward(data, **kwargs)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the registered inverse function to ``data``."""
        # filter kwargs for the same reason as in the forward pass
        kwargs = filter_kwargs(kwargs, self._inverse)
        return self._inverse(data, **kwargs)
def registered_but_changed(x): # noqa: F811 + return 2 * x + + # method instead of function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=A.fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=A.fn) + + # lambda function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x) + + # unregistered function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn) + + # function does not match registered function + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed) + + transform = SerializableCustomTransform(forward=registered_fn, inverse=registered_fn) + serialized_transform = keras.saving.serialize_keras_object(transform) + keras.saving.deserialize_keras_object(serialized_transform) + + # modify name of the forward function so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) + + # modify name of the inverse transform so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform)