From b95428d29a84399b1c2269ec4d1ea27878057998 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:17:49 -0400 Subject: [PATCH 1/6] add scale and shift transforms --- bayesflow/adapters/adapter.py | 12 ++++++++++++ bayesflow/adapters/transforms/__init__.py | 2 ++ bayesflow/adapters/transforms/scale.py | 14 ++++++++++++++ bayesflow/adapters/transforms/shift.py | 16 ++++++++++++++++ 4 files changed, 44 insertions(+) create mode 100644 bayesflow/adapters/transforms/scale.py create mode 100644 bayesflow/adapters/transforms/shift.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index e84e309fe..a96fb8162 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -555,6 +555,18 @@ def rename(self, from_key: str, to_key: str): self.transforms.append(Rename(from_key, to_key)) return self + def scale(self, keys: str | Sequence[str], by: float | np.ndarray): + from .transforms import Scale + + self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys})) + return self + + def shift(self, keys: str | Sequence[str], by: float | np.ndarray): + from .transforms import Shift + + self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys})) + return self + def sqrt(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.Sqrt` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index c0a1fce24..b3e95a494 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -14,6 +14,8 @@ from .numpy_transform import NumpyTransform from .one_hot import OneHot from .rename import Rename +from .scale import Scale +from .shift import Shift from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py new file mode 100644 index 000000000..71b5c08da --- /dev/null +++ b/bayesflow/adapters/transforms/scale.py @@ -0,0 +1,14 @@ +import numpy as np + +from .elementwise_transform import ElementwiseTransform + + +class Scale(ElementwiseTransform): + def __init__(self, scale: float | np.ndarray): + self.scale = scale + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data * self.scale + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data / self.scale diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py new file mode 100644 index 000000000..b5c099110 --- /dev/null +++ b/bayesflow/adapters/transforms/shift.py @@ -0,0 +1,16 @@ +import numpy as np + +from .elementwise_transform import ElementwiseTransform + +class Shift(ElementwiseTransform): + def __init__(self, shift: float | np.ndarray): + self.shift = shift + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data + self.shift + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return data - self.shift + + + From 1bfa9a59a19fed8b92fa8725ace43ef71ac45140 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:20:34 -0400 Subject: [PATCH 2/6] add scale and shift to tests --- tests/test_adapters/conftest.py | 2 ++ tests/test_adapters/test_adapters.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index b020523d9..99a7a1797 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -20,6 +20,8 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") + .scale("x", by=2) + .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") .one_hot("o1", 10) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 840a71ba2..a2ea19cb8 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -108,5 +108,6 @@ def test_simple_transforms(random_data): assert np.array_equal(inverse["p2"], random_data["p2"]) assert np.array_equal(inverse["t2"], random_data["t2"]) assert np.array_equal(inverse["t1"], random_data["t1"]) - # numerical inaccuries prevent np.array_equal to work here + + # numerical inaccuracies prevent np.array_equal to work here assert np.allclose(inverse["p1"], random_data["p1"]) From e051a48aa6994c365d8d3892e879f614b7205918 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:34:06 -0400 Subject: [PATCH 3/6] fix numerical accuracy in adapter test for simple transforms --- tests/test_adapters/test_adapters.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index a2ea19cb8..3dea0baf4 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -97,17 +97,16 @@ def test_simple_transforms(random_data): result = ad(random_data) - assert np.array_equal(result["p2"], np.log(random_data["p2"])) - assert np.array_equal(result["t2"], np.log(random_data["t2"])) - assert np.array_equal(result["t1"], np.log1p(random_data["t1"])) - assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) + assert np.allclose(result["p2"], np.log(random_data["p2"])) + assert np.allclose(result["t2"], np.log(random_data["t2"])) + assert np.allclose(result["t1"], np.log1p(random_data["t1"])) + assert np.allclose(result["p1"], np.sqrt(random_data["p1"])) # inverse results should match the original input inverse = ad(result, inverse=True) - assert np.array_equal(inverse["p2"], random_data["p2"]) - assert np.array_equal(inverse["t2"], random_data["t2"]) - assert np.array_equal(inverse["t1"], random_data["t1"]) + assert np.allclose(inverse["p2"], random_data["p2"]) + assert np.allclose(inverse["t2"], random_data["t2"]) + assert np.allclose(inverse["t1"], random_data["t1"]) - # numerical inaccuracies prevent np.array_equal to work here assert np.allclose(inverse["p1"], random_data["p1"]) From fa0d6c3618fb9cc84218cb040afe1c47a6acc1c8 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:34:18 -0400 Subject: [PATCH 4/6] add array-like scale test --- tests/test_adapters/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 99a7a1797..abd5797bd 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -20,7 +20,7 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") - .scale("x", by=2) + .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") From 0ddaad52e92763d904079341e67234fa9aac4c7a Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:34:26 -0400 Subject: [PATCH 5/6] make scale and shift serializable --- bayesflow/adapters/transforms/scale.py | 17 +++++++++++++++-- bayesflow/adapters/transforms/shift.py | 21 ++++++++++++++++----- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 71b5c08da..3e33893db 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -1,11 +1,24 @@ +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) import numpy as np from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class Scale(ElementwiseTransform): - def __init__(self, scale: float | np.ndarray): - self.scale = scale + def __init__(self, scale: np.typing.ArrayLike): + self.scale = np.array(scale) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform": + return cls(scale=deserialize(config["scale"])) + + def get_config(self) -> dict: + return {"scale": serialize(self.scale)} def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: return data * self.scale diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py index b5c099110..47d316027 100644 --- a/bayesflow/adapters/transforms/shift.py +++ b/bayesflow/adapters/transforms/shift.py @@ -1,16 +1,27 @@ +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) import numpy as np from .elementwise_transform import ElementwiseTransform + +@serializable(package="bayesflow.adapters") class Shift(ElementwiseTransform): - def __init__(self, shift: float | np.ndarray): - self.shift = shift + def __init__(self, shift: np.typing.ArrayLike): + self.shift = np.array(shift) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform": + return cls(shift=deserialize(config["shift"])) + + def get_config(self) -> dict: + return {"shift": serialize(self.shift)} def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: return data + self.shift def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data - self.shift - - - From f6c3b6d4d29147d87bdd521f2b4f8e6359574389 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 8 Apr 2025 11:41:27 -0400 Subject: [PATCH 6/6] fix scale and shift dispatch methods for single string input keys --- bayesflow/adapters/adapter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index a96fb8162..ebef412bf 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -558,12 +558,18 @@ def rename(self, from_key: str, to_key: str): def scale(self, keys: str | Sequence[str], by: float | np.ndarray): from .transforms import Scale + if isinstance(keys, str): + keys = [keys] + self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys})) return self def shift(self, keys: str | Sequence[str], by: float | np.ndarray): from .transforms import Shift + if isinstance(keys, str): + keys = [keys] + self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys})) return self