Skip to content

Commit 8426dd1

Browse files
authored
Numpy transforms (#350)
* add numpy transform * remove lambda transform * repurpose `adapter.apply()` for `NumpyTransform` instead of `LambdaTransform` * fix usages of `adapter.apply()` in example notebooks * fix tests * only allow strings as arguments (subject to be fixed by #323) * unify serialization pattern * remove old lambda transform error message * add fail fast to CI tests * cannot suppport filtering kwargs for numpy ufunc in python 3.10
1 parent 74b9673 commit 8426dd1

File tree

15 files changed

+107
-111
lines changed

15 files changed

+107
-111
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ jobs:
9595
9696
- name: Run Tests
9797
run: |
98-
pytest
98+
pytest -x
9999
100100
- name: Create Coverage Report
101101
run: |

bayesflow/adapters/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable, MutableSequence, Sequence
1+
from collections.abc import MutableSequence, Sequence
22

33
import numpy as np
44
from keras.saving import (
@@ -18,9 +18,9 @@
1818
ExpandDims,
1919
FilterTransform,
2020
Keep,
21-
LambdaTransform,
2221
Log,
2322
MapTransform,
23+
NumpyTransform,
2424
OneHot,
2525
Rename,
2626
Sqrt,
@@ -234,8 +234,8 @@ def __len__(self):
234234
def apply(
235235
self,
236236
*,
237-
forward: Callable[[np.ndarray, ...], np.ndarray],
238-
inverse: Callable[[np.ndarray, ...], np.ndarray],
237+
forward: np.ufunc | str,
238+
inverse: np.ufunc | str = None,
239239
predicate: Predicate = None,
240240
include: str | Sequence[str] = None,
241241
exclude: str | Sequence[str] = None,
@@ -271,7 +271,7 @@ def apply(
271271
to the `custom_objects` argument of the `deserialize` function when deserializing this class.
272272
"""
273273
transform = FilterTransform(
274-
transform_constructor=LambdaTransform,
274+
transform_constructor=NumpyTransform,
275275
predicate=predicate,
276276
include=include,
277277
exclude=exclude,

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from .expand_dims import ExpandDims
1010
from .filter_transform import FilterTransform
1111
from .keep import Keep
12-
from .lambda_transform import LambdaTransform
1312
from .log import Log
1413
from .map_transform import MapTransform
14+
from .numpy_transform import NumpyTransform
1515
from .one_hot import OneHot
1616
from .rename import Rename
1717
from .sqrt import Sqrt

bayesflow/adapters/transforms/as_set.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from keras.saving import register_keras_serializable as serializable
23

34
from .elementwise_transform import ElementwiseTransform
45

56

7+
@serializable(package="bayesflow.adapters")
68
class AsSet(ElementwiseTransform):
79
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
810

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from keras.saving import register_keras_serializable as serializable
23

34
from .elementwise_transform import ElementwiseTransform
45

56

7+
@serializable(package="bayesflow.adapters")
68
class AsTimeSeries(ElementwiseTransform):
79
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
810

bayesflow/adapters/transforms/expand_dims.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import numpy as np
2-
32
from keras.saving import (
43
deserialize_keras_object as deserialize,
4+
register_keras_serializable as serializable,
55
serialize_keras_object as serialize,
66
)
77

88
from .elementwise_transform import ElementwiseTransform
99

1010

11+
@serializable(package="bayesflow.adapters")
1112
class ExpandDims(ElementwiseTransform):
1213
"""
1314
Expand the shape of an array.

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,6 @@ def from_config(cls, config: dict, custom_objects=None) -> "Transform":
8686
try:
8787
kwargs = deserialize(config["kwargs"])
8888
except TypeError as e:
89-
if transform_constructor.__name__ == "LambdaTransform":
90-
raise TypeError(
91-
"LambdaTransform (created by Adapter.apply) could not be deserialized.\n"
92-
"This is probably because the custom transform functions `forward` and "
93-
"`backward` from `Adapter.apply` were not passed as `custom_objects`.\n"
94-
"For example, if your adapter uses\n"
95-
"`Adapter.apply(forward=forward_transform, inverse=inverse_transform)`,\n"
96-
"you have to pass\n"
97-
'`custom_objects={"forward_transform": forward_transform, '
98-
'"inverse_transform": inverse_transform}`\n'
99-
"to the function you use to load the serialized object."
100-
) from e
10189
raise TypeError(
10290
"The transform could not be deserialized properly. "
10391
"The most likely reason is that some classes or functions "

bayesflow/adapters/transforms/lambda_transform.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

bayesflow/adapters/transforms/log.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import numpy as np
2-
32
from keras.saving import (
43
deserialize_keras_object as deserialize,
4+
register_keras_serializable as serializable,
55
serialize_keras_object as serialize,
66
)
77

88
from .elementwise_transform import ElementwiseTransform
99

1010

11+
@serializable(package="bayesflow.adapters")
1112
class Log(ElementwiseTransform):
1213
"""Log transforms a variable.
1314
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
from keras.saving import register_keras_serializable as serializable
3+
4+
from .elementwise_transform import ElementwiseTransform
5+
6+
7+
@serializable(package="bayesflow.adapters")
8+
class NumpyTransform(ElementwiseTransform):
9+
"""
10+
A class to apply element-wise transformations using plain NumPy functions.
11+
12+
Attributes:
13+
----------
14+
_forward : str
15+
The name of the NumPy function to apply in the forward transformation.
16+
_inverse : str
17+
The name of the NumPy function to apply in the inverse transformation.
18+
"""
19+
20+
INVERSE_METHODS = {
21+
np.arctan: np.tan,
22+
np.exp: np.log,
23+
np.expm1: np.log1p,
24+
np.square: np.sqrt,
25+
np.reciprocal: np.reciprocal,
26+
}
27+
# ensure the map is symmetric
28+
INVERSE_METHODS |= {v: k for k, v in INVERSE_METHODS.items()}
29+
30+
def __init__(self, forward: str, inverse: str = None):
31+
"""
32+
Initializes the NumpyTransform with specified forward and inverse functions.
33+
34+
Parameters:
35+
----------
36+
forward: str
37+
The name of the NumPy function to use for the forward transformation.
38+
inverse: str, optional
39+
The name of the NumPy function to use for the inverse transformation.
40+
By default, the inverse is inferred from the forward argument for supported methods.
41+
"""
42+
super().__init__()
43+
44+
if isinstance(forward, str):
45+
forward = getattr(np, forward)
46+
47+
if not isinstance(forward, np.ufunc):
48+
raise ValueError("Forward transformation must be a NumPy Universal Function (ufunc).")
49+
50+
if inverse is None:
51+
if forward not in self.INVERSE_METHODS:
52+
raise ValueError(f"Cannot infer inverse for method {forward!r}")
53+
54+
inverse = self.INVERSE_METHODS[forward]
55+
56+
if isinstance(inverse, str):
57+
inverse = getattr(np, inverse)
58+
59+
if not isinstance(inverse, np.ufunc):
60+
raise ValueError("Inverse transformation must be a NumPy Universal Function (ufunc).")
61+
62+
self._forward = forward
63+
self._inverse = inverse
64+
65+
@classmethod
66+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
67+
return cls(
68+
forward=config["forward"],
69+
inverse=config["inverse"],
70+
)
71+
72+
def get_config(self) -> dict:
73+
return {"forward": self._forward.__name__, "inverse": self._inverse.__name__}
74+
75+
def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
76+
return self._forward(data)
77+
78+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
79+
return self._inverse(data)

0 commit comments

Comments
 (0)