Skip to content

Commit 512b323

Browse files
authored
Merge pull request #394 from bayesflow-org/adapter-scale-and-shift
Adapter scale and shift
2 parents 22c75d1 + f6c3b6d commit 512b323

File tree

6 files changed

+84
-8
lines changed

6 files changed

+84
-8
lines changed

bayesflow/adapters/adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,24 @@ def rename(self, from_key: str, to_key: str):
555555
self.transforms.append(Rename(from_key, to_key))
556556
return self
557557

558+
def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
559+
from .transforms import Scale
560+
561+
if isinstance(keys, str):
562+
keys = [keys]
563+
564+
self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys}))
565+
return self
566+
567+
def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
568+
from .transforms import Shift
569+
570+
if isinstance(keys, str):
571+
keys = [keys]
572+
573+
self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
574+
return self
575+
558576
def sqrt(self, keys: str | Sequence[str]):
559577
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
560578

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .numpy_transform import NumpyTransform
1515
from .one_hot import OneHot
1616
from .rename import Rename
17+
from .scale import Scale
18+
from .shift import Shift
1719
from .sqrt import Sqrt
1820
from .standardize import Standardize
1921
from .to_array import ToArray
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
6+
import numpy as np
7+
8+
from .elementwise_transform import ElementwiseTransform
9+
10+
11+
@serializable(package="bayesflow.adapters")
12+
class Scale(ElementwiseTransform):
13+
def __init__(self, scale: np.typing.ArrayLike):
14+
self.scale = np.array(scale)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(scale=deserialize(config["scale"]))
19+
20+
def get_config(self) -> dict:
21+
return {"scale": serialize(self.scale)}
22+
23+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
24+
return data * self.scale
25+
26+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
27+
return data / self.scale
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
6+
import numpy as np
7+
8+
from .elementwise_transform import ElementwiseTransform
9+
10+
11+
@serializable(package="bayesflow.adapters")
12+
class Shift(ElementwiseTransform):
13+
def __init__(self, shift: np.typing.ArrayLike):
14+
self.shift = np.array(shift)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(shift=deserialize(config["shift"]))
19+
20+
def get_config(self) -> dict:
21+
return {"shift": serialize(self.shift)}
22+
23+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
24+
return data + self.shift
25+
26+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
27+
return data - self.shift

tests/test_adapters/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def adapter():
2020
.constrain("p2", lower=0)
2121
.apply(include="p2", forward="exp", inverse="log")
2222
.apply(include="p2", forward="log1p")
23+
.scale("x", by=[-1, 2])
24+
.shift("x", by=2)
2325
.standardize(exclude=["t1", "t2", "o1"])
2426
.drop("d1")
2527
.one_hot("o1", 10)

tests/test_adapters/test_adapters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,16 @@ def test_simple_transforms(random_data):
9797

9898
result = ad(random_data)
9999

100-
assert np.array_equal(result["p2"], np.log(random_data["p2"]))
101-
assert np.array_equal(result["t2"], np.log(random_data["t2"]))
102-
assert np.array_equal(result["t1"], np.log1p(random_data["t1"]))
103-
assert np.array_equal(result["p1"], np.sqrt(random_data["p1"]))
100+
assert np.allclose(result["p2"], np.log(random_data["p2"]))
101+
assert np.allclose(result["t2"], np.log(random_data["t2"]))
102+
assert np.allclose(result["t1"], np.log1p(random_data["t1"]))
103+
assert np.allclose(result["p1"], np.sqrt(random_data["p1"]))
104104

105105
# inverse results should match the original input
106106
inverse = ad(result, inverse=True)
107107

108-
assert np.array_equal(inverse["p2"], random_data["p2"])
109-
assert np.array_equal(inverse["t2"], random_data["t2"])
110-
assert np.array_equal(inverse["t1"], random_data["t1"])
111-
# numerical inaccuries prevent np.array_equal to work here
108+
assert np.allclose(inverse["p2"], random_data["p2"])
109+
assert np.allclose(inverse["t2"], random_data["t2"])
110+
assert np.allclose(inverse["t1"], random_data["t1"])
111+
112112
assert np.allclose(inverse["p1"], random_data["p1"])

0 commit comments

Comments
 (0)