diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index e84e309fe..ebef412bf 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -555,6 +555,24 @@ def rename(self, from_key: str, to_key: str): self.transforms.append(Rename(from_key, to_key)) return self + def scale(self, keys: str | Sequence[str], by: float | np.ndarray): + from .transforms import Scale + + if isinstance(keys, str): + keys = [keys] + + self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys})) + return self + + def shift(self, keys: str | Sequence[str], by: float | np.ndarray): + from .transforms import Shift + + if isinstance(keys, str): + keys = [keys] + + self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys})) + return self + def sqrt(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.Sqrt` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index c0a1fce24..b3e95a494 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -14,6 +14,8 @@ from .numpy_transform import NumpyTransform from .one_hot import OneHot from .rename import Rename +from .scale import Scale +from .shift import Shift from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py new file mode 100644 index 000000000..3e33893db --- /dev/null +++ b/bayesflow/adapters/transforms/scale.py @@ -0,0 +1,27 @@ +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from .elementwise_transform import ElementwiseTransform + + +@serializable(package="bayesflow.adapters") +class Scale(ElementwiseTransform): + def __init__(self, scale: np.typing.ArrayLike): + self.scale = np.array(scale) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform": + return cls(scale=deserialize(config["scale"])) + + def get_config(self) -> dict: + return {"scale": serialize(self.scale)} + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data * self.scale + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data / self.scale diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py new file mode 100644 index 000000000..47d316027 --- /dev/null +++ b/bayesflow/adapters/transforms/shift.py @@ -0,0 +1,27 @@ +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from .elementwise_transform import ElementwiseTransform + + +@serializable(package="bayesflow.adapters") +class Shift(ElementwiseTransform): + def __init__(self, shift: np.typing.ArrayLike): + self.shift = np.array(shift) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform": + return cls(shift=deserialize(config["shift"])) + + def get_config(self) -> dict: + return {"shift": serialize(self.shift)} + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data + self.shift + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data - self.shift diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index b020523d9..abd5797bd 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -20,6 +20,8 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") + .scale("x", by=[-1, 2]) + .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") .one_hot("o1", 10) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 840a71ba2..3dea0baf4 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -97,16 +97,16 @@ def test_simple_transforms(random_data): result = ad(random_data) - assert np.array_equal(result["p2"], np.log(random_data["p2"])) - assert np.array_equal(result["t2"], np.log(random_data["t2"])) - assert np.array_equal(result["t1"], np.log1p(random_data["t1"])) - assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) + assert np.allclose(result["p2"], np.log(random_data["p2"])) + assert np.allclose(result["t2"], np.log(random_data["t2"])) + assert np.allclose(result["t1"], np.log1p(random_data["t1"])) + assert np.allclose(result["p1"], np.sqrt(random_data["p1"])) # inverse results should match the original input inverse = ad(result, inverse=True) - assert np.array_equal(inverse["p2"], random_data["p2"]) - assert np.array_equal(inverse["t2"], random_data["t2"]) - assert np.array_equal(inverse["t1"], random_data["t1"]) - # numerical inaccuries prevent np.array_equal to work here + assert np.allclose(inverse["p2"], random_data["p2"]) + assert np.allclose(inverse["t2"], random_data["t2"]) + assert np.allclose(inverse["t1"], random_data["t1"]) + assert np.allclose(inverse["p1"], random_data["p1"])