Skip to content

Commit 2ba93d4

Browse files
committed
Add class for custom transforms to adapter.
This commit reintroduces the features that were present in `LambdaTransform`, but only allowing registered functions. While being stricter, that allows for closer scaffolding and raising errors early on, so that users cannot provide functions that will not be (de)serializable later on. As there are a few failure modes, the focus is on providing detailed error messages to enable users to solve problems without external help.
1 parent 1284694 commit 2ba93d4

File tree

5 files changed

+346
-1
lines changed

5 files changed

+346
-1
lines changed

bayesflow/adapters/adapter.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import MutableSequence, Sequence, Mapping
1+
from collections.abc import Callable, MutableSequence, Sequence, Mapping
22

33
import numpy as np
44

@@ -24,6 +24,7 @@
2424
NumpyTransform,
2525
OneHot,
2626
Rename,
27+
SerializableCustomTransform,
2728
Sqrt,
2829
Standardize,
2930
ToArray,
@@ -283,6 +284,88 @@ def apply(
283284
self.transforms.append(transform)
284285
return self
285286

287+
def apply_serializable(
288+
self,
289+
include: str | Sequence[str] = None,
290+
*,
291+
serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray],
292+
serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray],
293+
predicate: Predicate = None,
294+
exclude: str | Sequence[str] = None,
295+
**kwargs,
296+
):
297+
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.
298+
299+
Parameters
300+
----------
301+
serializable_forward_fn : function, no lambda
302+
Registered serializable function to transform the data in the forward pass.
303+
For the adapter to be serializable, this function has to be serializable
304+
as well (see Notes). Therefore, only proper functions and no lambda
305+
functions can be used here.
306+
serializable_inverse_fn : function, no lambda
307+
Registered serializable function to transform the data in the inverse pass.
308+
For the adapter to be serializable, this function has to be serializable
309+
as well (see Notes). Therefore, only proper functions and no lambda
310+
functions can be used here.
311+
predicate : Predicate, optional
312+
Function that indicates which variables should be transformed.
313+
include : str or Sequence of str, optional
314+
Names of variables to include in the transform.
315+
exclude : str or Sequence of str, optional
316+
Names of variables to exclude from the transform.
317+
**kwargs : dict
318+
Additional keyword arguments passed to the transform.
319+
320+
Raises
321+
------
322+
ValueError
323+
When the provided functions are not registered serializable functions.
324+
325+
Notes
326+
-----
327+
Important: The forward and inverse functions have to be registered with Keras.
328+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
329+
They must also be registered (and identical) when loading the adapter
330+
at a later point in time.
331+
332+
Examples
333+
--------
334+
335+
The example below shows how to use the
336+
`keras.saving.register_keras_serializable` decorator to
337+
register functions with Keras. Note that for this simple
338+
example, one usually would use the simpler :py:meth:`apply`
339+
method.
340+
341+
>>> import keras
342+
>>>
343+
>>> @keras.saving.register_keras_serializable("custom")
344+
>>> def forward_fn(x):
345+
>>> return x**2
346+
>>>
347+
>>> @keras.saving.register_keras_serializable("custom")
348+
>>> def inverse_fn(x):
349+
>>> return x**0.5
350+
>>>
351+
>>> adapter = bf.Adapter().apply_serializable(
352+
>>> "x",
353+
>>> serializable_forward_fn=forward_fn,
354+
>>> serializable_inverse_fn=inverse_fn,
355+
>>> )
356+
"""
357+
transform = FilterTransform(
358+
transform_constructor=SerializableCustomTransform,
359+
predicate=predicate,
360+
include=include,
361+
exclude=exclude,
362+
serializable_forward_fn=serializable_forward_fn,
363+
serializable_inverse_fn=serializable_inverse_fn,
364+
**kwargs,
365+
)
366+
self.transforms.append(transform)
367+
return self
368+
286369
def as_set(self, keys: str | Sequence[str]):
287370
"""Append an :py:class:`~transforms.AsSet` transform to the adapter.
288371

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .one_hot import OneHot
1616
from .rename import Rename
1717
from .scale import Scale
18+
from .serializable_custom_transform import SerializableCustomTransform
1819
from .shift import Shift
1920
from .sqrt import Sqrt
2021
from .standardize import Standardize
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from collections.abc import Callable
2+
import numpy as np
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
register_keras_serializable as serializable,
6+
serialize_keras_object as serialize,
7+
get_registered_name,
8+
get_registered_object,
9+
)
10+
from .elementwise_transform import ElementwiseTransform
11+
from ...utils import filter_kwargs
12+
import inspect
13+
14+
15+
@serializable(package="bayesflow.adapters")
16+
class SerializableCustomTransform(ElementwiseTransform):
17+
"""
18+
Transforms a parameter using a pair of registered serializable forward and inverse functions.
19+
20+
Parameters
21+
----------
22+
serializable_forward_fn : function, no lambda
23+
Registered serializable function to transform the data in the forward pass.
24+
For the adapter to be serializable, this function has to be serializable
25+
as well (see Notes). Therefore, only proper functions and no lambda
26+
functions can be used here.
27+
serializable_inverse_fn : function, no lambda
28+
Function to transform the data in the inverse pass.
29+
For the adapter to be serializable, this function has to be serializable
30+
as well (see Notes). Therefore, only proper functions and no lambda
31+
functions can be used here.
32+
33+
Raises
34+
------
35+
ValueError
36+
When the provided functions are not registered serializable functions.
37+
38+
Notes
39+
-----
40+
Important: The forward and inverse functions have to be registered with Keras.
41+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
42+
They must also be registered (and identical) when loading the adapter
43+
at a later point in time.
44+
45+
"""
46+
47+
def __init__(
48+
self,
49+
*,
50+
serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray],
51+
serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray],
52+
):
53+
super().__init__()
54+
55+
self._check_serializable(serializable_forward_fn, label="serializable_forward_fn")
56+
self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn")
57+
self._forward = serializable_forward_fn
58+
self._inverse = serializable_inverse_fn
59+
60+
@classmethod
61+
def _check_serializable(cls, function, label=""):
62+
GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function:
63+
64+
```
65+
import keras
66+
67+
@keras.saving.register_keras_serializable('custom')
68+
def my_{label}(...):
69+
[your code goes here...]
70+
```
71+
"""
72+
if function is None:
73+
raise TypeError(
74+
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
75+
)
76+
registered_name = get_registered_name(function)
77+
# check if function is a lambda function
78+
if registered_name == "<lambda>":
79+
raise ValueError(
80+
f"The provided function for '{label}' is a lambda function, "
81+
"which cannot be serialized. "
82+
"Please provide a registered serializable function by using the "
83+
"@keras.saving.register_keras_serializable decorator."
84+
f"\n{GENERAL_EXAMPLE_CODE}"
85+
)
86+
if inspect.ismethod(function):
87+
raise ValueError(
88+
f"The provided value for '{label}' is a method, not a function. "
89+
"Methods cannot be serialized separately from their classes. "
90+
"Please provide a registered serializable function instead by "
91+
"moving the functionality to a function (i.e., outside of the class) and "
92+
"using the @keras.saving.register_keras_serializable decorator."
93+
f"\n{GENERAL_EXAMPLE_CODE}"
94+
)
95+
registered_object_for_name = get_registered_object(registered_name)
96+
if registered_object_for_name is None:
97+
try:
98+
source_max_lines = 5
99+
function_source_code = inspect.getsource(function).split("\n")
100+
if len(function_source_code) > source_max_lines:
101+
function_source_code = function_source_code[:source_max_lines] + [" [...]"]
102+
103+
example_code = "For your provided function, this would look like this:\n\n"
104+
example_code += "\n".join(
105+
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
106+
+ function_source_code
107+
+ ["```"]
108+
)
109+
except OSError:
110+
example_code = GENERAL_EXAMPLE_CODE
111+
raise ValueError(
112+
f"The provided function for '{label}' is not registered with Keras.\n"
113+
"Please register the function using the "
114+
"@keras.saving.register_keras_serializable decorator.\n"
115+
f"{example_code}"
116+
)
117+
if registered_object_for_name is not function:
118+
raise ValueError(
119+
f"The provided function for '{label}' does not match the function "
120+
f"registered under its name '{registered_name}'. "
121+
f"(registered function: {registered_object_for_name}, provided function: {function}). "
122+
)
123+
124+
@classmethod
125+
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
126+
if get_registered_object(config["forward"]["config"], custom_objects) is None:
127+
provided_function_msg = ""
128+
if config["_forward_source_code"]:
129+
provided_function_msg = (
130+
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
131+
)
132+
raise TypeError(
133+
"\n\nPLEASE READ HERE:\n"
134+
"-----------------\n"
135+
"The forward function that was provided as `serializable_forward_fn` "
136+
"is not registered with Keras, making deserialization impossible. "
137+
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
138+
"function before loading your model."
139+
f"{provided_function_msg}"
140+
)
141+
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
142+
provided_function_msg = ""
143+
if config["_inverse_source_code"]:
144+
provided_function_msg = (
145+
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
146+
)
147+
raise TypeError(
148+
"\n\nPLEASE READ HERE:\n"
149+
"-----------------\n"
150+
"The inverse function that was provided as `serializable_inverse_fn` "
151+
"is not registered with Keras, making deserialization impossible. "
152+
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
153+
"function before loading your model."
154+
f"{provided_function_msg}"
155+
)
156+
forward = deserialize(config["forward"], custom_objects)
157+
inverse = deserialize(config["inverse"], custom_objects)
158+
return cls(
159+
serializable_forward_fn=forward,
160+
serializable_inverse_fn=inverse,
161+
)
162+
163+
def get_config(self) -> dict:
164+
forward_source_code = inverse_source_code = None
165+
try:
166+
forward_source_code = inspect.getsource(self._forward)
167+
inverse_source_code = inspect.getsource(self._inverse)
168+
except OSError:
169+
pass
170+
return {
171+
"forward": serialize(self._forward),
172+
"inverse": serialize(self._inverse),
173+
"_forward_source_code": forward_source_code,
174+
"_inverse_source_code": inverse_source_code,
175+
}
176+
177+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
178+
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
179+
kwargs = filter_kwargs(kwargs, self._forward)
180+
return self._forward(data, **kwargs)
181+
182+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
183+
kwargs = filter_kwargs(kwargs, self._inverse)
184+
return self._inverse(data, **kwargs)

tests/test_adapters/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
@pytest.fixture()
66
def adapter():
77
from bayesflow.adapters import Adapter
8+
import keras
9+
10+
@keras.saving.register_keras_serializable("custom")
11+
def serializable_fn(x):
12+
return x
813

914
d = (
1015
Adapter()
@@ -20,6 +25,9 @@ def adapter():
2025
.constrain("p2", lower=0)
2126
.apply(include="p2", forward="exp", inverse="log")
2227
.apply(include="p2", forward="log1p")
28+
.apply_serializable(
29+
include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn
30+
)
2331
.scale("x", by=[-1, 2])
2432
.shift("x", by=2)
2533
.standardize(exclude=["t1", "t2", "o1"])

tests/test_adapters/test_adapters.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
serialize_keras_object as serialize,
44
)
55
import numpy as np
6+
import pytest
67

78

89
def test_cycle_consistency(adapter, random_data):
@@ -110,3 +111,71 @@ def test_simple_transforms(random_data):
110111
assert np.allclose(inverse["t1"], random_data["t1"])
111112

112113
assert np.allclose(inverse["p1"], random_data["p1"])
114+
115+
116+
def test_custom_transform():
117+
# test that transform raises errors in all relevant cases
118+
import keras
119+
from bayesflow.adapters.transforms import SerializableCustomTransform
120+
from copy import deepcopy
121+
122+
class A:
123+
@classmethod
124+
def fn(cls, x):
125+
return x
126+
127+
def not_registered_fn(x):
128+
return x
129+
130+
@keras.saving.register_keras_serializable("custom")
131+
def registered_fn(x):
132+
return x
133+
134+
@keras.saving.register_keras_serializable("custom")
135+
def registered_but_changed(x):
136+
return x
137+
138+
def registered_but_changed(x): # noqa: F811
139+
return 2 * x
140+
141+
# method instead of function provided
142+
with pytest.raises(ValueError):
143+
SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn)
144+
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn)
145+
146+
# lambda function provided
147+
with pytest.raises(ValueError):
148+
SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn)
149+
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x)
150+
151+
# unregistered function provided
152+
with pytest.raises(ValueError):
153+
SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn)
154+
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn)
155+
156+
# function does not match registered function
157+
with pytest.raises(ValueError):
158+
SerializableCustomTransform(
159+
serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn
160+
)
161+
SerializableCustomTransform(
162+
serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed
163+
)
164+
165+
transform = SerializableCustomTransform(
166+
serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn
167+
)
168+
serialized_transform = keras.saving.serialize_keras_object(transform)
169+
keras.saving.deserialize_keras_object(serialized_transform)
170+
171+
# modify name of the forward function so that it cannot be found
172+
corrupt_serialized_transform = deepcopy(serialized_transform)
173+
corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent"
174+
with pytest.raises(TypeError):
175+
keras.saving.deserialize_keras_object(corrupt_serialized_transform)
176+
177+
# modify name of the inverse transform so that it cannot be found
178+
corrupt_serialized_transform = deepcopy(serialized_transform)
179+
corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent"
180+
with pytest.raises(TypeError):
181+
keras.saving.deserialize_keras_object(corrupt_serialized_transform)

0 commit comments

Comments
 (0)