From 108cec9ff503789e098012298aa61b9bae3b9993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 10 Mar 2025 11:41:43 +0100 Subject: [PATCH 1/5] adapter: support log and sqrt transforms --- bayesflow/adapters/adapter.py | 34 ++++++++++++++++++ bayesflow/adapters/transforms/__init__.py | 2 ++ bayesflow/adapters/transforms/log.py | 43 +++++++++++++++++++++++ bayesflow/adapters/transforms/sqrt.py | 31 ++++++++++++++++ tests/test_adapters/test_adapters.py | 18 ++++++++++ 5 files changed, 128 insertions(+) create mode 100644 bayesflow/adapters/transforms/log.py create mode 100644 bayesflow/adapters/transforms/sqrt.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 94f7396f2..edf7c405f 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -19,9 +19,11 @@ FilterTransform, Keep, LambdaTransform, + Log, MapTransform, OneHot, Rename, + Sqrt, Standardize, ToArray, Transform, @@ -500,6 +502,23 @@ def keep(self, keys: str | Sequence[str]): self.transforms.append(transform) return self + def log(self, keys: str | Sequence[str], *, p1: bool = False): + """Append an :py:class:`~transforms.Log` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to expand. + p1 : boolean + Add 1 to the input before taking the logarithm? + """ + if isinstance(keys, str): + keys = [keys] + + transform = Log(keys, p1=p1) + self.transforms.append(transform) + return self + def one_hot(self, keys: str | Sequence[str], num_classes: int): """Append a :py:class:`~transforms.OneHot` transform to the adapter. @@ -530,6 +549,21 @@ def rename(self, from_key: str, to_key: str): self.transforms.append(Rename(from_key, to_key)) return self + def sqrt(self, keys: str | Sequence[str]): + """Append an :py:class:`~transforms.Sqrt` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to expand. + """ + if isinstance(keys, str): + keys = [keys] + + transform = Sqrt(keys) + self.transforms.append(transform) + return self + def standardize( self, *, diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index 92e294ed4..1c5211d51 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -10,9 +10,11 @@ from .filter_transform import FilterTransform from .keep import Keep from .lambda_transform import LambdaTransform +from .log import Log from .map_transform import MapTransform from .one_hot import OneHot from .rename import Rename +from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray from .transform import Transform diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py new file mode 100644 index 000000000..d458870e0 --- /dev/null +++ b/bayesflow/adapters/transforms/log.py @@ -0,0 +1,43 @@ +import numpy as np + +from collections.abc import Sequence + +from .elementwise_transform import ElementwiseTransform + + +class Log(ElementwiseTransform): + """Log transforms a variable. + + Parameters + ---------- + p1 : boolean + Add 1 to the input before taking the logarithm? + + Examples + -------- + >>> adapter = bf.Adapter().log(["x"]) + """ + + def __init__(self, keys: Sequence[str], *, p1: bool = False): + super().__init__() + self.keys = keys + self.p1 = p1 + + def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + if self.p1: + return {k: (np.log1p(v) if k in self.keys else v) for k, v in data.items()} + else: + return {k: (np.log(v) if k in self.keys else v) for k, v in data.items()} + + def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + if self.p1: + return {k: (np.expm1(v) if k in self.keys else v) for k, v in data.items()} + else: + return {k: (np.exp(v) if k in self.keys else v) for k, v in data.items()} + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "Log": + return cls() + + def get_config(self) -> dict: + return {} diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py new file mode 100644 index 000000000..f1f32183b --- /dev/null +++ b/bayesflow/adapters/transforms/sqrt.py @@ -0,0 +1,31 @@ +import numpy as np + +from collections.abc import Sequence + +from .elementwise_transform import ElementwiseTransform + + +class Sqrt(ElementwiseTransform): + """Square-root transform a variable. + + Examples + -------- + >>> adapter = bf.Adapter().sqrt(["x"]) + """ + + def __init__(self, keys: Sequence[str]): + super().__init__() + self.keys = keys + + def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + return {k: (np.sqrt(v) if k in self.keys else v) for k, v in data.items()} + + def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + return {k: (np.square(v) if k in self.keys else v) for k, v in data.items()} + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "Sqrt": + return cls() + + def get_config(self) -> dict: + return {} diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 7247869d7..0f4b23ad5 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -87,3 +87,21 @@ def test_constrain(): assert np.isinf(result["x_upper_disc2"][0]) assert np.isneginf(result["x_both_disc2"][0]) assert np.isinf(result["x_both_disc2"][-1]) + +def test_log_sqrt(random_data): + # check if constraint-implied transforms are applied correctly + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .log(["o1", "p2"]) + .log("t1", p1=True) + .sqrt("p1") + ) + + result = adapter(random_data) + + assert np.isfinite(result["o1"][0, 0]) + assert np.isfinite(result["p2"][0, 0]) + assert np.isfinite(result["t1"][0, 0]) + assert np.isfinite(result["p1"][0, 0]) From fc60b396dee00a69b2b610fde7aaa859a237895f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 10 Mar 2025 12:14:43 +0100 Subject: [PATCH 2/5] rerun linter --- tests/test_adapters/test_adapters.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 0f4b23ad5..de112fe32 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -88,16 +88,12 @@ def test_constrain(): assert np.isneginf(result["x_both_disc2"][0]) assert np.isinf(result["x_both_disc2"][-1]) + def test_log_sqrt(random_data): # check if constraint-implied transforms are applied correctly from bayesflow.adapters import Adapter - adapter = ( - Adapter() - .log(["o1", "p2"]) - .log("t1", p1=True) - .sqrt("p1") - ) + adapter = Adapter().log(["o1", "p2"]).log("t1", p1=True).sqrt("p1") result = adapter(random_data) From 49b45d7891a55f6ff434b0d7415c8da4b6185437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 10 Mar 2025 20:05:11 +0100 Subject: [PATCH 3/5] use map transform --- bayesflow/adapters/adapter.py | 10 +++---- bayesflow/adapters/transforms/expand_dims.py | 19 ++++--------- bayesflow/adapters/transforms/log.py | 28 ++++++++++++-------- bayesflow/adapters/transforms/sqrt.py | 14 +++------- tests/test_adapters/conftest.py | 3 +-- tests/test_adapters/test_adapters.py | 4 +-- 6 files changed, 34 insertions(+), 44 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index edf7c405f..62890f30b 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -483,7 +483,7 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple): if isinstance(keys, str): keys = [keys] - transform = ExpandDims(keys, axis=axis) + transform = MapTransform({key: ExpandDims(axis=axis) for key in keys}) self.transforms.append(transform) return self @@ -508,14 +508,14 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False): Parameters ---------- keys : str or Sequence of str - The names of the variables to expand. + The names of the variables to transform. p1 : boolean Add 1 to the input before taking the logarithm? """ if isinstance(keys, str): keys = [keys] - transform = Log(keys, p1=p1) + transform = MapTransform({key: Log(p1=p1) for key in keys}) self.transforms.append(transform) return self @@ -555,12 +555,12 @@ def sqrt(self, keys: str | Sequence[str]): Parameters ---------- keys : str or Sequence of str - The names of the variables to expand. + The names of the variables to transform. """ if isinstance(keys, str): keys = [keys] - transform = Sqrt(keys) + transform = MapTransform({key: Sqrt() for key in keys}) self.transforms.append(transform) return self diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index 25e75de82..6a9519d8e 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -5,7 +5,6 @@ serialize_keras_object as serialize, ) -from collections.abc import Sequence from .elementwise_transform import ElementwiseTransform @@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform): Parameters ---------- - keys : str or Sequence of str - The names of the variables to expand. axis : int or tuple The axis to expand. @@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform): It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform. """ - def __init__(self, keys: Sequence[str], *, axis: int | tuple): + def __init__(self, *, axis: int | tuple): super().__init__() - - self.keys = keys self.axis = axis @classmethod def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims": return cls( - keys=deserialize(config["keys"], custom_objects), axis=deserialize(config["axis"], custom_objects), ) def get_config(self) -> dict: return { - "keys": serialize(self.keys), "axis": serialize(self.axis), } - # noinspection PyMethodOverriding - def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.expand_dims(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()} + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.expand_dims(data, axis=self.axis) - # noinspection PyMethodOverriding - def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.squeeze(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()} + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.squeeze(data, axis=self.axis) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index d458870e0..cefe468b2 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -1,6 +1,9 @@ import numpy as np -from collections.abc import Sequence +from keras.saving import ( + deserialize_keras_object as deserialize, + serialize_keras_object as serialize, +) from .elementwise_transform import ElementwiseTransform @@ -18,26 +21,29 @@ class Log(ElementwiseTransform): >>> adapter = bf.Adapter().log(["x"]) """ - def __init__(self, keys: Sequence[str], *, p1: bool = False): + def __init__(self, *, p1: bool = False): super().__init__() - self.keys = keys self.p1 = p1 - def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: if self.p1: - return {k: (np.log1p(v) if k in self.keys else v) for k, v in data.items()} + return np.log1p(data) else: - return {k: (np.log(v) if k in self.keys else v) for k, v in data.items()} + return np.log(data) - def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: if self.p1: - return {k: (np.expm1(v) if k in self.keys else v) for k, v in data.items()} + return np.expm1(data) else: - return {k: (np.exp(v) if k in self.keys else v) for k, v in data.items()} + return np.exp(data) @classmethod def from_config(cls, config: dict, custom_objects=None) -> "Log": - return cls() + return cls( + p1=deserialize(config["p1"], custom_objects), + ) def get_config(self) -> dict: - return {} + return { + "p1": serialize(self.p1), + } diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index f1f32183b..88bb81a08 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -1,7 +1,5 @@ import numpy as np -from collections.abc import Sequence - from .elementwise_transform import ElementwiseTransform @@ -13,15 +11,11 @@ class Sqrt(ElementwiseTransform): >>> adapter = bf.Adapter().sqrt(["x"]) """ - def __init__(self, keys: Sequence[str]): - super().__init__() - self.keys = keys - - def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.sqrt(v) if k in self.keys else v) for k, v in data.items()} + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.sqrt(data) - def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.square(v) if k in self.keys else v) for k, v in data.items()} + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.square(data) @classmethod def from_config(cls, config: dict, custom_objects=None) -> "Sqrt": diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 74db29530..03f214578 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -30,8 +30,7 @@ def adapter(): .concatenate(["y1", "y2"], into="y") .expand_dims(["z1"], axis=2) .apply(forward=forward_transform, inverse=inverse_transform) - # TODO: fix this in keras - # .apply(include="p1", forward=np.log, inverse=np.exp) + .log("p1") .constrain("p2", lower=0) .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index de112fe32..ebac89063 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -89,8 +89,8 @@ def test_constrain(): assert np.isinf(result["x_both_disc2"][-1]) -def test_log_sqrt(random_data): - # check if constraint-implied transforms are applied correctly +def test_simple_transforms(random_data): + # check if simple transforms are applied correctly from bayesflow.adapters import Adapter adapter = Adapter().log(["o1", "p2"]).log("t1", p1=True).sqrt("p1") From 7091f01c5fd17938d6ee5df7a89f5fc2082c4152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 11 Mar 2025 09:00:42 +0100 Subject: [PATCH 4/5] improve adapter tests --- tests/test_adapters/test_adapters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index ebac89063..5f633730d 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -97,7 +97,7 @@ def test_simple_transforms(random_data): result = adapter(random_data) - assert np.isfinite(result["o1"][0, 0]) - assert np.isfinite(result["p2"][0, 0]) - assert np.isfinite(result["t1"][0, 0]) - assert np.isfinite(result["p1"][0, 0]) + assert np.array_equal(result["o1"], np.log(random_data["o1"])) + assert np.array_equal(result["p2"], np.log(random_data["p2"])) + assert np.array_equal(result["t1"], np.log1p(random_data["t1"])) + assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) From 5e2da0f6c0b58fc2deea7a7e3e15d25096badb91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 11 Mar 2025 09:41:08 +0100 Subject: [PATCH 5/5] also test inverse of simple transforms --- tests/test_adapters/test_adapters.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 5f633730d..efd58bd6e 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -93,11 +93,20 @@ def test_simple_transforms(random_data): # check if simple transforms are applied correctly from bayesflow.adapters import Adapter - adapter = Adapter().log(["o1", "p2"]).log("t1", p1=True).sqrt("p1") + adapter = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1") result = adapter(random_data) - assert np.array_equal(result["o1"], np.log(random_data["o1"])) assert np.array_equal(result["p2"], np.log(random_data["p2"])) + assert np.array_equal(result["t2"], np.log(random_data["t2"])) assert np.array_equal(result["t1"], np.log1p(random_data["t1"])) assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) + + # inverse results should match the original input + inverse = adapter.inverse(result) + + assert np.array_equal(inverse["p2"], random_data["p2"]) + assert np.array_equal(inverse["t2"], random_data["t2"]) + assert np.array_equal(inverse["t1"], random_data["t1"]) + # numerical inaccuries prevent np.array_equal to work here + assert np.allclose(inverse["p1"], random_data["p1"])