Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,18 @@ def rename(self, from_key: str, to_key: str):
self.transforms.append(Rename(from_key, to_key))
return self

def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
from .transforms import Scale

self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys}))
Copy link

Copilot AI Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MapTransform is referenced but not imported. Please add an import statement for MapTransform to ensure it is defined.

Copilot uses AI. Check for mistakes.
return self

def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
from .transforms import Shift

self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
return self

def sqrt(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.

Expand Down
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .numpy_transform import NumpyTransform
from .one_hot import OneHot
from .rename import Rename
from .scale import Scale
from .shift import Shift
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
Expand Down
27 changes: 27 additions & 0 deletions bayesflow/adapters/transforms/scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
import numpy as np

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class Scale(ElementwiseTransform):
def __init__(self, scale: np.typing.ArrayLike):
self.scale = np.array(scale)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
return cls(scale=deserialize(config["scale"]))

def get_config(self) -> dict:
return {"scale": serialize(self.scale)}

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data * self.scale

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data / self.scale
27 changes: 27 additions & 0 deletions bayesflow/adapters/transforms/shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
import numpy as np

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class Shift(ElementwiseTransform):
def __init__(self, shift: np.typing.ArrayLike):
self.shift = np.array(shift)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
return cls(shift=deserialize(config["shift"]))

def get_config(self) -> dict:
return {"shift": serialize(self.shift)}

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data + self.shift

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data - self.shift
2 changes: 2 additions & 0 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def adapter():
.constrain("p2", lower=0)
.apply(include="p2", forward="exp", inverse="log")
.apply(include="p2", forward="log1p")
.scale("x", by=[-1, 2])
.shift("x", by=2)
.standardize(exclude=["t1", "t2", "o1"])
.drop("d1")
.one_hot("o1", 10)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@ def test_simple_transforms(random_data):

result = ad(random_data)

assert np.array_equal(result["p2"], np.log(random_data["p2"]))
assert np.array_equal(result["t2"], np.log(random_data["t2"]))
assert np.array_equal(result["t1"], np.log1p(random_data["t1"]))
assert np.array_equal(result["p1"], np.sqrt(random_data["p1"]))
assert np.allclose(result["p2"], np.log(random_data["p2"]))
assert np.allclose(result["t2"], np.log(random_data["t2"]))
assert np.allclose(result["t1"], np.log1p(random_data["t1"]))
assert np.allclose(result["p1"], np.sqrt(random_data["p1"]))

# inverse results should match the original input
inverse = ad(result, inverse=True)

assert np.array_equal(inverse["p2"], random_data["p2"])
assert np.array_equal(inverse["t2"], random_data["t2"])
assert np.array_equal(inverse["t1"], random_data["t1"])
# numerical inaccuries prevent np.array_equal to work here
assert np.allclose(inverse["p2"], random_data["p2"])
assert np.allclose(inverse["t2"], random_data["t2"])
assert np.allclose(inverse["t1"], random_data["t1"])

assert np.allclose(inverse["p1"], random_data["p1"])
Loading