Skip to content

Commit 8fdfde9

Browse files
committed
custom transform: less verbose naming, fix tests
1 parent 2ba93d4 commit 8fdfde9

File tree

4 files changed

+44
-49
lines changed

4 files changed

+44
-49
lines changed

bayesflow/adapters/adapter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ def apply_serializable(
288288
self,
289289
include: str | Sequence[str] = None,
290290
*,
291-
serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray],
292-
serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray],
291+
forward: Callable[[np.ndarray, ...], np.ndarray],
292+
inverse: Callable[[np.ndarray, ...], np.ndarray],
293293
predicate: Predicate = None,
294294
exclude: str | Sequence[str] = None,
295295
**kwargs,
@@ -298,12 +298,12 @@ def apply_serializable(
298298
299299
Parameters
300300
----------
301-
serializable_forward_fn : function, no lambda
301+
forward : function, no lambda
302302
Registered serializable function to transform the data in the forward pass.
303303
For the adapter to be serializable, this function has to be serializable
304304
as well (see Notes). Therefore, only proper functions and no lambda
305305
functions can be used here.
306-
serializable_inverse_fn : function, no lambda
306+
inverse : function, no lambda
307307
Registered serializable function to transform the data in the inverse pass.
308308
For the adapter to be serializable, this function has to be serializable
309309
as well (see Notes). Therefore, only proper functions and no lambda
@@ -350,17 +350,17 @@ def apply_serializable(
350350
>>>
351351
>>> adapter = bf.Adapter().apply_serializable(
352352
>>> "x",
353-
>>> serializable_forward_fn=forward_fn,
354-
>>> serializable_inverse_fn=inverse_fn,
353+
>>> forward=forward_fn,
354+
>>> inverse=inverse_fn,
355355
>>> )
356356
"""
357357
transform = FilterTransform(
358358
transform_constructor=SerializableCustomTransform,
359359
predicate=predicate,
360360
include=include,
361361
exclude=exclude,
362-
serializable_forward_fn=serializable_forward_fn,
363-
serializable_inverse_fn=serializable_inverse_fn,
362+
forward=forward,
363+
inverse=inverse,
364364
**kwargs,
365365
)
366366
self.transforms.append(transform)

bayesflow/adapters/transforms/serializable_custom_transform.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ class SerializableCustomTransform(ElementwiseTransform):
1919
2020
Parameters
2121
----------
22-
serializable_forward_fn : function, no lambda
22+
forward : function, no lambda
2323
Registered serializable function to transform the data in the forward pass.
2424
For the adapter to be serializable, this function has to be serializable
2525
as well (see Notes). Therefore, only proper functions and no lambda
2626
functions can be used here.
27-
serializable_inverse_fn : function, no lambda
27+
inverse : function, no lambda
2828
Function to transform the data in the inverse pass.
2929
For the adapter to be serializable, this function has to be serializable
3030
as well (see Notes). Therefore, only proper functions and no lambda
@@ -47,28 +47,27 @@ class SerializableCustomTransform(ElementwiseTransform):
4747
def __init__(
4848
self,
4949
*,
50-
serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray],
51-
serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray],
50+
forward: Callable[[np.ndarray, ...], np.ndarray],
51+
inverse: Callable[[np.ndarray, ...], np.ndarray],
5252
):
5353
super().__init__()
5454

55-
self._check_serializable(serializable_forward_fn, label="serializable_forward_fn")
56-
self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn")
57-
self._forward = serializable_forward_fn
58-
self._inverse = serializable_inverse_fn
55+
self._check_serializable(forward, label="forward")
56+
self._check_serializable(inverse, label="inverse")
57+
self._forward = forward
58+
self._inverse = inverse
5959

6060
@classmethod
6161
def _check_serializable(cls, function, label=""):
62-
GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function:
63-
64-
```
65-
import keras
66-
67-
@keras.saving.register_keras_serializable('custom')
68-
def my_{label}(...):
69-
[your code goes here...]
70-
```
71-
"""
62+
GENERAL_EXAMPLE_CODE = (
63+
"The example code below shows the structure of a correctly decorated function:\n\n"
64+
"```\n"
65+
"import keras\n\n"
66+
"@keras.saving.register_keras_serializable('custom')\n"
67+
f"def my_{label}(...):\n"
68+
" [your code goes here...]\n"
69+
"```\n"
70+
)
7271
if function is None:
7372
raise TypeError(
7473
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
@@ -132,7 +131,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr
132131
raise TypeError(
133132
"\n\nPLEASE READ HERE:\n"
134133
"-----------------\n"
135-
"The forward function that was provided as `serializable_forward_fn` "
134+
"The forward function that was provided as `forward` "
136135
"is not registered with Keras, making deserialization impossible. "
137136
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
138137
"function before loading your model."
@@ -147,7 +146,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr
147146
raise TypeError(
148147
"\n\nPLEASE READ HERE:\n"
149148
"-----------------\n"
150-
"The inverse function that was provided as `serializable_inverse_fn` "
149+
"The inverse function that was provided as `inverse` "
151150
"is not registered with Keras, making deserialization impossible. "
152151
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
153152
"function before loading your model."
@@ -156,8 +155,8 @@ def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTr
156155
forward = deserialize(config["forward"], custom_objects)
157156
inverse = deserialize(config["inverse"], custom_objects)
158157
return cls(
159-
serializable_forward_fn=forward,
160-
serializable_inverse_fn=inverse,
158+
forward=forward,
159+
inverse=inverse,
161160
)
162161

163162
def get_config(self) -> dict:

tests/test_adapters/conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def serializable_fn(x):
2525
.constrain("p2", lower=0)
2626
.apply(include="p2", forward="exp", inverse="log")
2727
.apply(include="p2", forward="log1p")
28-
.apply_serializable(
29-
include="x", serializable_forward_fn=serializable_fn, serializable_inverse_fn=serializable_fn
30-
)
28+
.apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn)
3129
.scale("x", by=[-1, 2])
3230
.shift("x", by=2)
3331
.standardize(exclude=["t1", "t2", "o1"])

tests/test_adapters/test_adapters.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -140,31 +140,29 @@ def registered_but_changed(x): # noqa: F811
140140

141141
# method instead of function provided
142142
with pytest.raises(ValueError):
143-
SerializableCustomTransform(serializable_forward_fn=A.fn, serializable_inverse_fn=registered_fn)
144-
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=A.fn)
143+
SerializableCustomTransform(forward=A.fn, inverse=registered_fn)
144+
with pytest.raises(ValueError):
145+
SerializableCustomTransform(forward=registered_fn, inverse=A.fn)
145146

146147
# lambda function provided
147148
with pytest.raises(ValueError):
148-
SerializableCustomTransform(serializable_forward_fn=lambda x: x, serializable_inverse_fn=registered_fn)
149-
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=lambda x: x)
149+
SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn)
150+
with pytest.raises(ValueError):
151+
SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x)
150152

151153
# unregistered function provided
152154
with pytest.raises(ValueError):
153-
SerializableCustomTransform(serializable_forward_fn=not_registered_fn, serializable_inverse_fn=registered_fn)
154-
SerializableCustomTransform(serializable_forward_fn=registered_fn, serializable_inverse_fn=not_registered_fn)
155+
SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn)
156+
with pytest.raises(ValueError):
157+
SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn)
155158

156159
# function does not match registered function
157160
with pytest.raises(ValueError):
158-
SerializableCustomTransform(
159-
serializable_forward_fn=registered_but_changed, serializable_inverse_fn=registered_fn
160-
)
161-
SerializableCustomTransform(
162-
serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_but_changed
163-
)
164-
165-
transform = SerializableCustomTransform(
166-
serializable_forward_fn=registered_fn, serializable_inverse_fn=registered_fn
167-
)
161+
SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn)
162+
with pytest.raises(ValueError):
163+
SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed)
164+
165+
transform = SerializableCustomTransform(forward=registered_fn, inverse=registered_fn)
168166
serialized_transform = keras.saving.serialize_keras_object(transform)
169167
keras.saving.deserialize_keras_object(serialized_transform)
170168

0 commit comments

Comments
 (0)