Skip to content

Commit 7537ee4

Browse files
Merge pull request #348 from bayesflow-org/log_sqrt_transforms
adapter: support log and sqrt transforms
2 parents 1dfdf0a + 5e2da0f commit 7537ee4

File tree

7 files changed

+140
-17
lines changed

7 files changed

+140
-17
lines changed

bayesflow/adapters/adapter.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
FilterTransform,
2020
Keep,
2121
LambdaTransform,
22+
Log,
2223
MapTransform,
2324
OneHot,
2425
Rename,
26+
Sqrt,
2527
Standardize,
2628
ToArray,
2729
Transform,
@@ -487,7 +489,7 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
487489
if isinstance(keys, str):
488490
keys = [keys]
489491

490-
transform = ExpandDims(keys, axis=axis)
492+
transform = MapTransform({key: ExpandDims(axis=axis) for key in keys})
491493
self.transforms.append(transform)
492494
return self
493495

@@ -506,6 +508,23 @@ def keep(self, keys: str | Sequence[str]):
506508
self.transforms.append(transform)
507509
return self
508510

511+
def log(self, keys: str | Sequence[str], *, p1: bool = False):
512+
"""Append an :py:class:`~transforms.Log` transform to the adapter.
513+
514+
Parameters
515+
----------
516+
keys : str or Sequence of str
517+
The names of the variables to transform.
518+
p1 : boolean
519+
Add 1 to the input before taking the logarithm?
520+
"""
521+
if isinstance(keys, str):
522+
keys = [keys]
523+
524+
transform = MapTransform({key: Log(p1=p1) for key in keys})
525+
self.transforms.append(transform)
526+
return self
527+
509528
def one_hot(self, keys: str | Sequence[str], num_classes: int):
510529
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
511530
@@ -536,6 +555,21 @@ def rename(self, from_key: str, to_key: str):
536555
self.transforms.append(Rename(from_key, to_key))
537556
return self
538557

558+
def sqrt(self, keys: str | Sequence[str]):
559+
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
560+
561+
Parameters
562+
----------
563+
keys : str or Sequence of str
564+
The names of the variables to transform.
565+
"""
566+
if isinstance(keys, str):
567+
keys = [keys]
568+
569+
transform = MapTransform({key: Sqrt() for key in keys})
570+
self.transforms.append(transform)
571+
return self
572+
539573
def standardize(
540574
self,
541575
*,

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from .filter_transform import FilterTransform
1111
from .keep import Keep
1212
from .lambda_transform import LambdaTransform
13+
from .log import Log
1314
from .map_transform import MapTransform
1415
from .one_hot import OneHot
1516
from .rename import Rename
17+
from .sqrt import Sqrt
1618
from .standardize import Standardize
1719
from .to_array import ToArray
1820
from .transform import Transform

bayesflow/adapters/transforms/expand_dims.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
serialize_keras_object as serialize,
66
)
77

8-
from collections.abc import Sequence
98
from .elementwise_transform import ElementwiseTransform
109

1110

@@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform):
1514
1615
Parameters
1716
----------
18-
keys : str or Sequence of str
19-
The names of the variables to expand.
2017
axis : int or tuple
2118
The axis to expand.
2219
@@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform):
4946
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
5047
"""
5148

52-
def __init__(self, keys: Sequence[str], *, axis: int | tuple):
49+
def __init__(self, *, axis: int | tuple):
5350
super().__init__()
54-
55-
self.keys = keys
5651
self.axis = axis
5752

5853
@classmethod
5954
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
6055
return cls(
61-
keys=deserialize(config["keys"], custom_objects),
6256
axis=deserialize(config["axis"], custom_objects),
6357
)
6458

6559
def get_config(self) -> dict:
6660
return {
67-
"keys": serialize(self.keys),
6861
"axis": serialize(self.axis),
6962
}
7063

71-
# noinspection PyMethodOverriding
72-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
73-
return {k: (np.expand_dims(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
64+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
65+
return np.expand_dims(data, axis=self.axis)
7466

75-
# noinspection PyMethodOverriding
76-
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
77-
return {k: (np.squeeze(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
67+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
68+
return np.squeeze(data, axis=self.axis)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
serialize_keras_object as serialize,
6+
)
7+
8+
from .elementwise_transform import ElementwiseTransform
9+
10+
11+
class Log(ElementwiseTransform):
12+
"""Log transforms a variable.
13+
14+
Parameters
15+
----------
16+
p1 : boolean
17+
Add 1 to the input before taking the logarithm?
18+
19+
Examples
20+
--------
21+
>>> adapter = bf.Adapter().log(["x"])
22+
"""
23+
24+
def __init__(self, *, p1: bool = False):
25+
super().__init__()
26+
self.p1 = p1
27+
28+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
29+
if self.p1:
30+
return np.log1p(data)
31+
else:
32+
return np.log(data)
33+
34+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
35+
if self.p1:
36+
return np.expm1(data)
37+
else:
38+
return np.exp(data)
39+
40+
@classmethod
41+
def from_config(cls, config: dict, custom_objects=None) -> "Log":
42+
return cls(
43+
p1=deserialize(config["p1"], custom_objects),
44+
)
45+
46+
def get_config(self) -> dict:
47+
return {
48+
"p1": serialize(self.p1),
49+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
from .elementwise_transform import ElementwiseTransform
4+
5+
6+
class Sqrt(ElementwiseTransform):
7+
"""Square-root transform a variable.
8+
9+
Examples
10+
--------
11+
>>> adapter = bf.Adapter().sqrt(["x"])
12+
"""
13+
14+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
15+
return np.sqrt(data)
16+
17+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
18+
return np.square(data)
19+
20+
@classmethod
21+
def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":
22+
return cls()
23+
24+
def get_config(self) -> dict:
25+
return {}

tests/test_adapters/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def adapter():
3030
.concatenate(["y1", "y2"], into="y")
3131
.expand_dims(["z1"], axis=2)
3232
.apply(forward=forward_transform, inverse=inverse_transform)
33-
# TODO: fix this in keras
34-
# .apply(include="p1", forward=np.log, inverse=np.exp)
33+
.log("p1")
3534
.constrain("p2", lower=0)
3635
.standardize(exclude=["t1", "t2", "o1"])
3736
.drop("d1")

tests/test_adapters/test_adapters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,26 @@ def test_constrain():
8787
assert np.isinf(result["x_upper_disc2"][0])
8888
assert np.isneginf(result["x_both_disc2"][0])
8989
assert np.isinf(result["x_both_disc2"][-1])
90+
91+
92+
def test_simple_transforms(random_data):
93+
# check if simple transforms are applied correctly
94+
from bayesflow.adapters import Adapter
95+
96+
adapter = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1")
97+
98+
result = adapter(random_data)
99+
100+
assert np.array_equal(result["p2"], np.log(random_data["p2"]))
101+
assert np.array_equal(result["t2"], np.log(random_data["t2"]))
102+
assert np.array_equal(result["t1"], np.log1p(random_data["t1"]))
103+
assert np.array_equal(result["p1"], np.sqrt(random_data["p1"]))
104+
105+
# inverse results should match the original input
106+
inverse = adapter.inverse(result)
107+
108+
assert np.array_equal(inverse["p2"], random_data["p2"])
109+
assert np.array_equal(inverse["t2"], random_data["t2"])
110+
assert np.array_equal(inverse["t1"], random_data["t1"])
111+
# numerical inaccuries prevent np.array_equal to work here
112+
assert np.allclose(inverse["p1"], random_data["p1"])

0 commit comments

Comments
 (0)