diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 94f7396f2..62890f30b 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -19,9 +19,11 @@ FilterTransform, Keep, LambdaTransform, + Log, MapTransform, OneHot, Rename, + Sqrt, Standardize, ToArray, Transform, @@ -481,7 +483,7 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple): if isinstance(keys, str): keys = [keys] - transform = ExpandDims(keys, axis=axis) + transform = MapTransform({key: ExpandDims(axis=axis) for key in keys}) self.transforms.append(transform) return self @@ -500,6 +502,23 @@ def keep(self, keys: str | Sequence[str]): self.transforms.append(transform) return self + def log(self, keys: str | Sequence[str], *, p1: bool = False): + """Append an :py:class:`~transforms.Log` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to transform. + p1 : boolean + Add 1 to the input before taking the logarithm? + """ + if isinstance(keys, str): + keys = [keys] + + transform = MapTransform({key: Log(p1=p1) for key in keys}) + self.transforms.append(transform) + return self + def one_hot(self, keys: str | Sequence[str], num_classes: int): """Append a :py:class:`~transforms.OneHot` transform to the adapter. @@ -530,6 +549,21 @@ def rename(self, from_key: str, to_key: str): self.transforms.append(Rename(from_key, to_key)) return self + def sqrt(self, keys: str | Sequence[str]): + """Append an :py:class:`~transforms.Sqrt` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to transform. + """ + if isinstance(keys, str): + keys = [keys] + + transform = MapTransform({key: Sqrt() for key in keys}) + self.transforms.append(transform) + return self + def standardize( self, *, diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index 92e294ed4..1c5211d51 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -10,9 +10,11 @@ from .filter_transform import FilterTransform from .keep import Keep from .lambda_transform import LambdaTransform +from .log import Log from .map_transform import MapTransform from .one_hot import OneHot from .rename import Rename +from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray from .transform import Transform diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index 25e75de82..6a9519d8e 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -5,7 +5,6 @@ serialize_keras_object as serialize, ) -from collections.abc import Sequence from .elementwise_transform import ElementwiseTransform @@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform): Parameters ---------- - keys : str or Sequence of str - The names of the variables to expand. axis : int or tuple The axis to expand. @@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform): It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform. """ - def __init__(self, keys: Sequence[str], *, axis: int | tuple): + def __init__(self, *, axis: int | tuple): super().__init__() - - self.keys = keys self.axis = axis @classmethod def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims": return cls( - keys=deserialize(config["keys"], custom_objects), axis=deserialize(config["axis"], custom_objects), ) def get_config(self) -> dict: return { - "keys": serialize(self.keys), "axis": serialize(self.axis), } - # noinspection PyMethodOverriding - def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.expand_dims(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()} + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.expand_dims(data, axis=self.axis) - # noinspection PyMethodOverriding - def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: - return {k: (np.squeeze(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()} + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.squeeze(data, axis=self.axis) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py new file mode 100644 index 000000000..cefe468b2 --- /dev/null +++ b/bayesflow/adapters/transforms/log.py @@ -0,0 +1,49 @@ +import numpy as np + +from keras.saving import ( + deserialize_keras_object as deserialize, + serialize_keras_object as serialize, +) + +from .elementwise_transform import ElementwiseTransform + + +class Log(ElementwiseTransform): + """Log transforms a variable. + + Parameters + ---------- + p1 : boolean + Add 1 to the input before taking the logarithm? + + Examples + -------- + >>> adapter = bf.Adapter().log(["x"]) + """ + + def __init__(self, *, p1: bool = False): + super().__init__() + self.p1 = p1 + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + if self.p1: + return np.log1p(data) + else: + return np.log(data) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + if self.p1: + return np.expm1(data) + else: + return np.exp(data) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "Log": + return cls( + p1=deserialize(config["p1"], custom_objects), + ) + + def get_config(self) -> dict: + return { + "p1": serialize(self.p1), + } diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py new file mode 100644 index 000000000..88bb81a08 --- /dev/null +++ b/bayesflow/adapters/transforms/sqrt.py @@ -0,0 +1,25 @@ +import numpy as np + +from .elementwise_transform import ElementwiseTransform + + +class Sqrt(ElementwiseTransform): + """Square-root transform a variable. + + Examples + -------- + >>> adapter = bf.Adapter().sqrt(["x"]) + """ + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.sqrt(data) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.square(data) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "Sqrt": + return cls() + + def get_config(self) -> dict: + return {} diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 74db29530..03f214578 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -30,8 +30,7 @@ def adapter(): .concatenate(["y1", "y2"], into="y") .expand_dims(["z1"], axis=2) .apply(forward=forward_transform, inverse=inverse_transform) - # TODO: fix this in keras - # .apply(include="p1", forward=np.log, inverse=np.exp) + .log("p1") .constrain("p2", lower=0) .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 7247869d7..efd58bd6e 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -87,3 +87,26 @@ def test_constrain(): assert np.isinf(result["x_upper_disc2"][0]) assert np.isneginf(result["x_both_disc2"][0]) assert np.isinf(result["x_both_disc2"][-1]) + + +def test_simple_transforms(random_data): + # check if simple transforms are applied correctly + from bayesflow.adapters import Adapter + + adapter = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1") + + result = adapter(random_data) + + assert np.array_equal(result["p2"], np.log(random_data["p2"])) + assert np.array_equal(result["t2"], np.log(random_data["t2"])) + assert np.array_equal(result["t1"], np.log1p(random_data["t1"])) + assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) + + # inverse results should match the original input + inverse = adapter.inverse(result) + + assert np.array_equal(inverse["p2"], random_data["p2"]) + assert np.array_equal(inverse["t2"], random_data["t2"]) + assert np.array_equal(inverse["t1"], random_data["t1"]) + # numerical inaccuries prevent np.array_equal to work here + assert np.allclose(inverse["p1"], random_data["p1"])