Skip to content

Commit 53d48d4

Browse files
committed
only allow strings as arguments (subject to be fixed by #323)
1 parent 9ebab59 commit 53d48d4

File tree

2 files changed

+26
-25
lines changed

2 files changed

+26
-25
lines changed
Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
2+
from keras.saving import register_keras_serializable as serializable
23

34
from bayesflow.utils import filter_kwargs
45
from .elementwise_transform import ElementwiseTransform
56

67

8+
@serializable(package="bayesflow.adapters")
79
class NumpyTransform(ElementwiseTransform):
810
"""
911
A class to apply element-wise transformations using plain NumPy functions.
@@ -17,45 +19,46 @@ class NumpyTransform(ElementwiseTransform):
1719
"""
1820

1921
INVERSE_METHODS = {
20-
"arctan": "tan",
21-
"exp": "log",
22-
"expm1": "log1p",
23-
"square": "sqrt",
24-
"reciprocal": "reciprocal",
22+
np.arctan: np.tan,
23+
np.exp: np.log,
24+
np.expm1: np.log1p,
25+
np.square: np.sqrt,
26+
np.reciprocal: np.reciprocal,
2527
}
2628
# ensure the map is symmetric
2729
INVERSE_METHODS |= {v: k for k, v in INVERSE_METHODS.items()}
2830

29-
def __init__(self, forward: np.ufunc | str, inverse: np.ufunc | str = None):
31+
def __init__(self, forward: str, inverse: str = None):
3032
"""
3133
Initializes the NumpyTransform with specified forward and inverse functions.
3234
3335
Parameters:
3436
----------
35-
forward : str
37+
forward: str
3638
The name of the NumPy function to use for the forward transformation.
37-
inverse : str
39+
inverse: str, optional
3840
The name of the NumPy function to use for the inverse transformation.
3941
By default, the inverse is inferred from the forward argument for supported methods.
4042
"""
4143
super().__init__()
4244

43-
if isinstance(forward, np.ufunc):
44-
forward = forward.__name__
45+
if isinstance(forward, str):
46+
forward = getattr(np, forward)
47+
48+
if not isinstance(forward, np.ufunc):
49+
raise ValueError("Forward transformation must be a NumPy Universal Function (ufunc).")
4550

4651
if inverse is None:
4752
if forward not in self.INVERSE_METHODS:
4853
raise ValueError(f"Cannot infer inverse for method {forward!r}")
4954

5055
inverse = self.INVERSE_METHODS[forward]
51-
elif isinstance(inverse, np.ufunc):
52-
inverse = inverse.__name__
5356

54-
if forward not in dir(np):
55-
raise ValueError(f"Method {forward!r} not found in numpy version {np.__version__}")
57+
if isinstance(inverse, str):
58+
inverse = getattr(np, inverse)
5659

57-
if inverse not in dir(np):
58-
raise ValueError(f"Method {inverse!r} not found in numpy version {np.__version__}")
60+
if not isinstance(inverse, np.ufunc):
61+
raise ValueError("Inverse transformation must be a NumPy Universal Function (ufunc).")
5962

6063
self._forward = forward
6164
self._inverse = inverse
@@ -68,14 +71,12 @@ def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform
6871
)
6972

7073
def get_config(self) -> dict:
71-
return {"forward": self._forward, "inverse": self._inverse}
74+
return {"forward": self._forward.__name__, "inverse": self._inverse.__name__}
7275

7376
def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
74-
forward = getattr(np, self._forward)
75-
kwargs = filter_kwargs(kwargs, forward)
76-
return forward(data, **kwargs)
77+
kwargs = filter_kwargs(kwargs, self._forward)
78+
return self._forward(data, **kwargs)
7779

7880
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
79-
inverse = getattr(np, self._inverse)
80-
kwargs = filter_kwargs(kwargs, inverse)
81-
return inverse(data, **kwargs)
81+
kwargs = filter_kwargs(kwargs, self._inverse)
82+
return self._inverse(data, **kwargs)

tests/test_adapters/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def adapter():
1818
.expand_dims(["z1"], axis=2)
1919
.log("p1")
2020
.constrain("p2", lower=0)
21-
.apply(include="p2", forward=np.exp, inverse=np.log)
22-
.apply(include="p2", forward="logp1")
21+
.apply(include="p2", forward="exp", inverse="log")
22+
.apply(include="p2", forward="log1p")
2323
.standardize(exclude=["t1", "t2", "o1"])
2424
.drop("d1")
2525
.one_hot("o1", 10)

0 commit comments

Comments
 (0)