Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import MutableSequence, Sequence, Mapping
from collections.abc import Callable, MutableSequence, Sequence, Mapping

import numpy as np

Expand All @@ -24,6 +24,7 @@
NumpyTransform,
OneHot,
Rename,
SerializableCustomTransform,
Sqrt,
Standardize,
ToArray,
Expand Down Expand Up @@ -274,6 +275,88 @@ def apply(
self.transforms.append(transform)
return self

def apply_serializable(
self,
include: str | Sequence[str] = None,
*,
forward: Callable[[np.ndarray, ...], np.ndarray],
inverse: Callable[[np.ndarray, ...], np.ndarray],
predicate: Predicate = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.

Parameters
----------
forward : function, no lambda
Registered serializable function to transform the data in the forward pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions and no lambda
functions can be used here.
inverse : function, no lambda
Registered serializable function to transform the data in the inverse pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions and no lambda
functions can be used here.
predicate : Predicate, optional
Function that indicates which variables should be transformed.
include : str or Sequence of str, optional
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
**kwargs : dict
Additional keyword arguments passed to the transform.

Raises
------
ValueError
When the provided functions are not registered serializable functions.

Notes
-----
Important: The forward and inverse functions have to be registered with Keras.
To do so, use the `@keras.saving.register_keras_serializable` decorator.
They must also be registered (and identical) when loading the adapter
at a later point in time.

Examples
--------

The example below shows how to use the
`keras.saving.register_keras_serializable` decorator to
register functions with Keras. Note that for this simple
example, one usually would use the simpler :py:meth:`apply`
method.

>>> import keras
>>>
>>> @keras.saving.register_keras_serializable("custom")
>>> def forward_fn(x):
>>> return x**2
>>>
>>> @keras.saving.register_keras_serializable("custom")
>>> def inverse_fn(x):
>>> return x**0.5
>>>
>>> adapter = bf.Adapter().apply_serializable(
>>> "x",
>>> forward=forward_fn,
>>> inverse=inverse_fn,
>>> )
"""
transform = FilterTransform(
transform_constructor=SerializableCustomTransform,
predicate=predicate,
include=include,
exclude=exclude,
forward=forward,
inverse=inverse,
**kwargs,
)
self.transforms.append(transform)
return self

def as_set(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.AsSet` transform to the adapter.

Expand Down
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .one_hot import OneHot
from .rename import Rename
from .scale import Scale
from .serializable_custom_transform import SerializableCustomTransform
from .shift import Shift
from .sqrt import Sqrt
from .standardize import Standardize
Expand Down
183 changes: 183 additions & 0 deletions bayesflow/adapters/transforms/serializable_custom_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from collections.abc import Callable
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
get_registered_name,
get_registered_object,
)
from .elementwise_transform import ElementwiseTransform
from ...utils import filter_kwargs
import inspect


@serializable(package="bayesflow.adapters")
class SerializableCustomTransform(ElementwiseTransform):
"""
Transforms a parameter using a pair of registered serializable forward and inverse functions.

Parameters
----------
forward : function, no lambda
Registered serializable function to transform the data in the forward pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions and no lambda
functions can be used here.
inverse : function, no lambda
Function to transform the data in the inverse pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions and no lambda
functions can be used here.

Raises
------
ValueError
When the provided functions are not registered serializable functions.

Notes
-----
Important: The forward and inverse functions have to be registered with Keras.
To do so, use the `@keras.saving.register_keras_serializable` decorator.
They must also be registered (and identical) when loading the adapter
at a later point in time.

"""

def __init__(
self,
*,
forward: Callable[[np.ndarray, ...], np.ndarray],
inverse: Callable[[np.ndarray, ...], np.ndarray],
):
super().__init__()

self._check_serializable(forward, label="forward")
self._check_serializable(inverse, label="inverse")
self._forward = forward
self._inverse = inverse

@classmethod
def _check_serializable(cls, function, label=""):
GENERAL_EXAMPLE_CODE = (
"The example code below shows the structure of a correctly decorated function:\n\n"
"```\n"
"import keras\n\n"
"@keras.saving.register_keras_serializable('custom')\n"
f"def my_{label}(...):\n"
" [your code goes here...]\n"
"```\n"
)
if function is None:
raise TypeError(

Check warning on line 72 in bayesflow/adapters/transforms/serializable_custom_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/serializable_custom_transform.py#L72

Added line #L72 was not covered by tests
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
)
registered_name = get_registered_name(function)
# check if function is a lambda function
if registered_name == "<lambda>":
raise ValueError(
f"The provided function for '{label}' is a lambda function, "
"which cannot be serialized. "
"Please provide a registered serializable function by using the "
"@keras.saving.register_keras_serializable decorator."
f"\n{GENERAL_EXAMPLE_CODE}"
)
if inspect.ismethod(function):
raise ValueError(
f"The provided value for '{label}' is a method, not a function. "
"Methods cannot be serialized separately from their classes. "
"Please provide a registered serializable function instead by "
"moving the functionality to a function (i.e., outside of the class) and "
"using the @keras.saving.register_keras_serializable decorator."
f"\n{GENERAL_EXAMPLE_CODE}"
)
registered_object_for_name = get_registered_object(registered_name)
if registered_object_for_name is None:
try:
source_max_lines = 5
function_source_code = inspect.getsource(function).split("\n")
if len(function_source_code) > source_max_lines:
function_source_code = function_source_code[:source_max_lines] + [" [...]"]

Check warning on line 100 in bayesflow/adapters/transforms/serializable_custom_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/serializable_custom_transform.py#L100

Added line #L100 was not covered by tests

example_code = "For your provided function, this would look like this:\n\n"
example_code += "\n".join(
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
+ function_source_code
+ ["```"]
)
except OSError:
example_code = GENERAL_EXAMPLE_CODE

Check warning on line 109 in bayesflow/adapters/transforms/serializable_custom_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/serializable_custom_transform.py#L108-L109

Added lines #L108 - L109 were not covered by tests
raise ValueError(
f"The provided function for '{label}' is not registered with Keras.\n"
"Please register the function using the "
"@keras.saving.register_keras_serializable decorator.\n"
f"{example_code}"
)
if registered_object_for_name is not function:
raise ValueError(

Check warning on line 117 in bayesflow/adapters/transforms/serializable_custom_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/serializable_custom_transform.py#L117

Added line #L117 was not covered by tests
f"The provided function for '{label}' does not match the function "
f"registered under its name '{registered_name}'. "
f"(registered function: {registered_object_for_name}, provided function: {function}). "
)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
if get_registered_object(config["forward"]["config"], custom_objects) is None:
provided_function_msg = ""
if config["_forward_source_code"]:
provided_function_msg = (
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
)
raise TypeError(
"\n\nPLEASE READ HERE:\n"
"-----------------\n"
"The forward function that was provided as `forward` "
"is not registered with Keras, making deserialization impossible. "
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
"function before loading your model."
f"{provided_function_msg}"
)
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
provided_function_msg = ""
if config["_inverse_source_code"]:
provided_function_msg = (
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
)
raise TypeError(
"\n\nPLEASE READ HERE:\n"
"-----------------\n"
"The inverse function that was provided as `inverse` "
"is not registered with Keras, making deserialization impossible. "
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
"function before loading your model."
f"{provided_function_msg}"
)
forward = deserialize(config["forward"], custom_objects)
inverse = deserialize(config["inverse"], custom_objects)
return cls(
forward=forward,
inverse=inverse,
)

def get_config(self) -> dict:
forward_source_code = inverse_source_code = None
try:
forward_source_code = inspect.getsource(self._forward)
inverse_source_code = inspect.getsource(self._inverse)
except OSError:
pass

Check warning on line 168 in bayesflow/adapters/transforms/serializable_custom_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/serializable_custom_transform.py#L167-L168

Added lines #L167 - L168 were not covered by tests
return {
"forward": serialize(self._forward),
"inverse": serialize(self._inverse),
"_forward_source_code": forward_source_code,
"_inverse_source_code": inverse_source_code,
}

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
kwargs = filter_kwargs(kwargs, self._forward)
return self._forward(data, **kwargs)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
kwargs = filter_kwargs(kwargs, self._inverse)
return self._inverse(data, **kwargs)
6 changes: 6 additions & 0 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
@pytest.fixture()
def adapter():
from bayesflow.adapters import Adapter
import keras

@keras.saving.register_keras_serializable("custom")
def serializable_fn(x):
return x

d = (
Adapter()
Expand All @@ -20,6 +25,7 @@ def adapter():
.constrain("p2", lower=0)
.apply(include="p2", forward="exp", inverse="log")
.apply(include="p2", forward="log1p")
.apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn)
.scale("x", by=[-1, 2])
.shift("x", by=2)
.standardize(exclude=["t1", "t2", "o1"])
Expand Down
67 changes: 67 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
serialize_keras_object as serialize,
)
import numpy as np
import pytest


def test_cycle_consistency(adapter, random_data):
Expand Down Expand Up @@ -110,3 +111,69 @@ def test_simple_transforms(random_data):
assert np.allclose(inverse["t1"], random_data["t1"])

assert np.allclose(inverse["p1"], random_data["p1"])


def test_custom_transform():
# test that transform raises errors in all relevant cases
import keras
from bayesflow.adapters.transforms import SerializableCustomTransform
from copy import deepcopy

class A:
@classmethod
def fn(cls, x):
return x

def not_registered_fn(x):
return x

@keras.saving.register_keras_serializable("custom")
def registered_fn(x):
return x

@keras.saving.register_keras_serializable("custom")
def registered_but_changed(x):
return x

def registered_but_changed(x): # noqa: F811
return 2 * x

# method instead of function provided
with pytest.raises(ValueError):
SerializableCustomTransform(forward=A.fn, inverse=registered_fn)
with pytest.raises(ValueError):
SerializableCustomTransform(forward=registered_fn, inverse=A.fn)

# lambda function provided
with pytest.raises(ValueError):
SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn)
with pytest.raises(ValueError):
SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x)

# unregistered function provided
with pytest.raises(ValueError):
SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn)
with pytest.raises(ValueError):
SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn)

# function does not match registered function
with pytest.raises(ValueError):
SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn)
with pytest.raises(ValueError):
SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed)

transform = SerializableCustomTransform(forward=registered_fn, inverse=registered_fn)
serialized_transform = keras.saving.serialize_keras_object(transform)
keras.saving.deserialize_keras_object(serialized_transform)

# modify name of the forward function so that it cannot be found
corrupt_serialized_transform = deepcopy(serialized_transform)
corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent"
with pytest.raises(TypeError):
keras.saving.deserialize_keras_object(corrupt_serialized_transform)

# modify name of the inverse transform so that it cannot be found
corrupt_serialized_transform = deepcopy(serialized_transform)
corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent"
with pytest.raises(TypeError):
keras.saving.deserialize_keras_object(corrupt_serialized_transform)