diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 6e8175fb8..bfd5583cd 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -95,7 +95,7 @@ jobs: - name: Run Tests run: | - pytest + pytest -x - name: Create Coverage Report run: | diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 724be81bf..1fd303959 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, MutableSequence, Sequence +from collections.abc import MutableSequence, Sequence import numpy as np from keras.saving import ( @@ -18,9 +18,9 @@ ExpandDims, FilterTransform, Keep, - LambdaTransform, Log, MapTransform, + NumpyTransform, OneHot, Rename, Sqrt, @@ -234,8 +234,8 @@ def __len__(self): def apply( self, *, - forward: Callable[[np.ndarray, ...], np.ndarray], - inverse: Callable[[np.ndarray, ...], np.ndarray], + forward: np.ufunc | str, + inverse: np.ufunc | str = None, predicate: Predicate = None, include: str | Sequence[str] = None, exclude: str | Sequence[str] = None, @@ -271,7 +271,7 @@ def apply( to the `custom_objects` argument of the `deserialize` function when deserializing this class. """ transform = FilterTransform( - transform_constructor=LambdaTransform, + transform_constructor=NumpyTransform, predicate=predicate, include=include, exclude=exclude, diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index 1c5211d51..c0a1fce24 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -9,9 +9,9 @@ from .expand_dims import ExpandDims from .filter_transform import FilterTransform from .keep import Keep -from .lambda_transform import LambdaTransform from .log import Log from .map_transform import MapTransform +from .numpy_transform import NumpyTransform from .one_hot import OneHot from .rename import Rename from .sqrt import Sqrt diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index 2e0fe86e1..6e8e5567e 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -1,8 +1,10 @@ import numpy as np +from keras.saving import register_keras_serializable as serializable from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class AsSet(ElementwiseTransform): """The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets. diff --git a/bayesflow/adapters/transforms/as_time_series.py b/bayesflow/adapters/transforms/as_time_series.py index 201181547..9453af817 100644 --- a/bayesflow/adapters/transforms/as_time_series.py +++ b/bayesflow/adapters/transforms/as_time_series.py @@ -1,8 +1,10 @@ import numpy as np +from keras.saving import register_keras_serializable as serializable from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class AsTimeSeries(ElementwiseTransform): """The `.as_time_series` transform can be used to indicate that variables shall be treated as time series. diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index 6a9519d8e..eb4d712f4 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -1,13 +1,14 @@ import numpy as np - from keras.saving import ( deserialize_keras_object as deserialize, + register_keras_serializable as serializable, serialize_keras_object as serialize, ) from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class ExpandDims(ElementwiseTransform): """ Expand the shape of an array. diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 1a3d40b7a..65f5750c8 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -86,18 +86,6 @@ def from_config(cls, config: dict, custom_objects=None) -> "Transform": try: kwargs = deserialize(config["kwargs"]) except TypeError as e: - if transform_constructor.__name__ == "LambdaTransform": - raise TypeError( - "LambdaTransform (created by Adapter.apply) could not be deserialized.\n" - "This is probably because the custom transform functions `forward` and " - "`backward` from `Adapter.apply` were not passed as `custom_objects`.\n" - "For example, if your adapter uses\n" - "`Adapter.apply(forward=forward_transform, inverse=inverse_transform)`,\n" - "you have to pass\n" - '`custom_objects={"forward_transform": forward_transform, ' - '"inverse_transform": inverse_transform}`\n' - "to the function you use to load the serialized object." - ) from e raise TypeError( "The transform could not be deserialized properly. " "The most likely reason is that some classes or functions " diff --git a/bayesflow/adapters/transforms/lambda_transform.py b/bayesflow/adapters/transforms/lambda_transform.py deleted file mode 100644 index 91dc4c8a7..000000000 --- a/bayesflow/adapters/transforms/lambda_transform.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Callable -import numpy as np -from keras.saving import ( - deserialize_keras_object as deserialize, - register_keras_serializable as serializable, - serialize_keras_object as serialize, -) -from .elementwise_transform import ElementwiseTransform -from ...utils import filter_kwargs - - -@serializable(package="bayesflow.adapters") -class LambdaTransform(ElementwiseTransform): - """ - Transforms a parameter using a pair of forward and inverse functions. - - Parameters - ---------- - forward : callable, no lambda - Function to transform the data in the forward pass. - For the adapter to be serializable, this function has to be serializable - as well (see Notes). Therefore, only proper functions and no lambda - functions should be used here. - inverse : callable, no lambda - Function to transform the data in the inverse pass. - For the adapter to be serializable, this function has to be serializable - as well (see Notes). Therefore, only proper functions and no lambda - functions should be used here. - - Notes - ----- - Important: This class is only serializable if the forward and inverse functions are serializable. - This most likely means you will have to pass the scope that the forward and inverse functions are contained in - to the `custom_objects` argument of the `deserialize` function when deserializing this class. - """ - - def __init__( - self, *, forward: Callable[[np.ndarray, ...], np.ndarray], inverse: Callable[[np.ndarray, ...], np.ndarray] - ): - super().__init__() - - self._forward = forward - self._inverse = inverse - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "LambdaTransform": - return cls( - forward=deserialize(config["forward"], custom_objects), - inverse=deserialize(config["inverse"], custom_objects), - ) - - def get_config(self) -> dict: - return { - "forward": serialize(self._forward), - "inverse": serialize(self._inverse), - } - - def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: - # filter kwargs so that other transform args like batch_size, strict, ... are not passed through - kwargs = filter_kwargs(kwargs, self._forward) - return self._forward(data, **kwargs) - - def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: - kwargs = filter_kwargs(kwargs, self._inverse) - return self._inverse(data, **kwargs) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index cefe468b2..e264fccfa 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -1,13 +1,14 @@ import numpy as np - from keras.saving import ( deserialize_keras_object as deserialize, + register_keras_serializable as serializable, serialize_keras_object as serialize, ) from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class Log(ElementwiseTransform): """Log transforms a variable. diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py new file mode 100644 index 000000000..817d98b17 --- /dev/null +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -0,0 +1,79 @@ +import numpy as np +from keras.saving import register_keras_serializable as serializable + +from .elementwise_transform import ElementwiseTransform + + +@serializable(package="bayesflow.adapters") +class NumpyTransform(ElementwiseTransform): + """ + A class to apply element-wise transformations using plain NumPy functions. + + Attributes: + ---------- + _forward : str + The name of the NumPy function to apply in the forward transformation. + _inverse : str + The name of the NumPy function to apply in the inverse transformation. + """ + + INVERSE_METHODS = { + np.arctan: np.tan, + np.exp: np.log, + np.expm1: np.log1p, + np.square: np.sqrt, + np.reciprocal: np.reciprocal, + } + # ensure the map is symmetric + INVERSE_METHODS |= {v: k for k, v in INVERSE_METHODS.items()} + + def __init__(self, forward: str, inverse: str = None): + """ + Initializes the NumpyTransform with specified forward and inverse functions. + + Parameters: + ---------- + forward: str + The name of the NumPy function to use for the forward transformation. + inverse: str, optional + The name of the NumPy function to use for the inverse transformation. + By default, the inverse is inferred from the forward argument for supported methods. + """ + super().__init__() + + if isinstance(forward, str): + forward = getattr(np, forward) + + if not isinstance(forward, np.ufunc): + raise ValueError("Forward transformation must be a NumPy Universal Function (ufunc).") + + if inverse is None: + if forward not in self.INVERSE_METHODS: + raise ValueError(f"Cannot infer inverse for method {forward!r}") + + inverse = self.INVERSE_METHODS[forward] + + if isinstance(inverse, str): + inverse = getattr(np, inverse) + + if not isinstance(inverse, np.ufunc): + raise ValueError("Inverse transformation must be a NumPy Universal Function (ufunc).") + + self._forward = forward + self._inverse = inverse + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform": + return cls( + forward=config["forward"], + inverse=config["inverse"], + ) + + def get_config(self) -> dict: + return {"forward": self._forward.__name__, "inverse": self._inverse.__name__} + + def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]: + return self._forward(data) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return self._inverse(data) diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 88bb81a08..b9f693e75 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -1,8 +1,10 @@ import numpy as np +from keras.saving import register_keras_serializable as serializable from .elementwise_transform import ElementwiseTransform +@serializable(package="bayesflow.adapters") class Sqrt(ElementwiseTransform): """Square-root transform a variable. diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index 393d31b6e..aefb51040 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -1,9 +1,7 @@ from numbers import Number import numpy as np -from keras.saving import ( - register_keras_serializable as serializable, -) +from keras.saving import register_keras_serializable as serializable from bayesflow.utils.io import deserialize_type, serialize_type from .elementwise_transform import ElementwiseTransform diff --git a/examples/SIR_Posterior_Estimation.ipynb b/examples/SIR_Posterior_Estimation.ipynb index 4a4b1786d..15e478a47 100644 --- a/examples/SIR_Posterior_Estimation.ipynb +++ b/examples/SIR_Posterior_Estimation.ipynb @@ -375,7 +375,7 @@ " # since all our variables are non-negative (zero or larger)\n", " # this .apply call ensures that the variables are transformed\n", " # to the unconstrained real space and can be back-transformed under the hood\n", - " .apply(forward=lambda x: np.log1p(x), inverse=lambda x: np.expm1(x))\n", + " .apply(forward=np.log1p)\n", ")" ] }, diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 03f214578..b020523d9 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -2,19 +2,6 @@ import pytest -def forward_transform(x): - return x + 1 - - -def inverse_transform(x): - return x - 1 - - -@pytest.fixture() -def custom_objects(): - return dict(forward_transform=forward_transform, inverse_transform=inverse_transform) - - @pytest.fixture() def adapter(): from bayesflow.adapters import Adapter @@ -29,9 +16,10 @@ def adapter(): .concatenate(["x1", "x2"], into="x") .concatenate(["y1", "y2"], into="y") .expand_dims(["z1"], axis=2) - .apply(forward=forward_transform, inverse=inverse_transform) .log("p1") .constrain("p2", lower=0) + .apply(include="p2", forward="exp", inverse="log") + .apply(include="p2", forward="log1p") .standardize(exclude=["t1", "t2", "o1"]) .drop("d1") .one_hot("o1", 10) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index efd58bd6e..840a71ba2 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -17,10 +17,10 @@ def test_cycle_consistency(adapter, random_data): assert np.allclose(value, deprocessed[key]) -def test_serialize_deserialize(adapter, custom_objects, random_data): +def test_serialize_deserialize(adapter, random_data): processed = adapter(random_data) serialized = serialize(adapter) - deserialized = deserialize(serialized, custom_objects) + deserialized = deserialize(serialized) reserialized = serialize(deserialized) assert reserialized.keys() == serialized.keys() @@ -51,7 +51,7 @@ def test_constrain(): "x_both_disc2": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))), } - adapter = ( + ad = ( Adapter() .constrain("x_lower_cont", lower=0) .constrain("x_upper_cont", upper=0) @@ -66,7 +66,7 @@ def test_constrain(): with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) - result = adapter(data) + result = ad(data) # continuous variables should not have boundary issues assert result["x_lower_cont"].min() < 0.0 @@ -93,9 +93,9 @@ def test_simple_transforms(random_data): # check if simple transforms are applied correctly from bayesflow.adapters import Adapter - adapter = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1") + ad = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1") - result = adapter(random_data) + result = ad(random_data) assert np.array_equal(result["p2"], np.log(random_data["p2"])) assert np.array_equal(result["t2"], np.log(random_data["t2"])) @@ -103,7 +103,7 @@ def test_simple_transforms(random_data): assert np.array_equal(result["p1"], np.sqrt(random_data["p1"])) # inverse results should match the original input - inverse = adapter.inverse(result) + inverse = ad(result, inverse=True) assert np.array_equal(inverse["p2"], random_data["p2"]) assert np.array_equal(inverse["t2"], random_data["t2"])