Skip to content

Commit 49b45d7

Browse files
committed
use map transform
1 parent fc60b39 commit 49b45d7

File tree

6 files changed

+34
-44
lines changed

6 files changed

+34
-44
lines changed

bayesflow/adapters/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
483483
if isinstance(keys, str):
484484
keys = [keys]
485485

486-
transform = ExpandDims(keys, axis=axis)
486+
transform = MapTransform({key: ExpandDims(axis=axis) for key in keys})
487487
self.transforms.append(transform)
488488
return self
489489

@@ -508,14 +508,14 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
508508
Parameters
509509
----------
510510
keys : str or Sequence of str
511-
The names of the variables to expand.
511+
The names of the variables to transform.
512512
p1 : boolean
513513
Add 1 to the input before taking the logarithm?
514514
"""
515515
if isinstance(keys, str):
516516
keys = [keys]
517517

518-
transform = Log(keys, p1=p1)
518+
transform = MapTransform({key: Log(p1=p1) for key in keys})
519519
self.transforms.append(transform)
520520
return self
521521

@@ -555,12 +555,12 @@ def sqrt(self, keys: str | Sequence[str]):
555555
Parameters
556556
----------
557557
keys : str or Sequence of str
558-
The names of the variables to expand.
558+
The names of the variables to transform.
559559
"""
560560
if isinstance(keys, str):
561561
keys = [keys]
562562

563-
transform = Sqrt(keys)
563+
transform = MapTransform({key: Sqrt() for key in keys})
564564
self.transforms.append(transform)
565565
return self
566566

bayesflow/adapters/transforms/expand_dims.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
serialize_keras_object as serialize,
66
)
77

8-
from collections.abc import Sequence
98
from .elementwise_transform import ElementwiseTransform
109

1110

@@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform):
1514
1615
Parameters
1716
----------
18-
keys : str or Sequence of str
19-
The names of the variables to expand.
2017
axis : int or tuple
2118
The axis to expand.
2219
@@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform):
4946
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
5047
"""
5148

52-
def __init__(self, keys: Sequence[str], *, axis: int | tuple):
49+
def __init__(self, *, axis: int | tuple):
5350
super().__init__()
54-
55-
self.keys = keys
5651
self.axis = axis
5752

5853
@classmethod
5954
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
6055
return cls(
61-
keys=deserialize(config["keys"], custom_objects),
6256
axis=deserialize(config["axis"], custom_objects),
6357
)
6458

6559
def get_config(self) -> dict:
6660
return {
67-
"keys": serialize(self.keys),
6861
"axis": serialize(self.axis),
6962
}
7063

71-
# noinspection PyMethodOverriding
72-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
73-
return {k: (np.expand_dims(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
64+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
65+
return np.expand_dims(data, axis=self.axis)
7466

75-
# noinspection PyMethodOverriding
76-
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
77-
return {k: (np.squeeze(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
67+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
68+
return np.squeeze(data, axis=self.axis)
Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3-
from collections.abc import Sequence
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
serialize_keras_object as serialize,
6+
)
47

58
from .elementwise_transform import ElementwiseTransform
69

@@ -18,26 +21,29 @@ class Log(ElementwiseTransform):
1821
>>> adapter = bf.Adapter().log(["x"])
1922
"""
2023

21-
def __init__(self, keys: Sequence[str], *, p1: bool = False):
24+
def __init__(self, *, p1: bool = False):
2225
super().__init__()
23-
self.keys = keys
2426
self.p1 = p1
2527

26-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
28+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
2729
if self.p1:
28-
return {k: (np.log1p(v) if k in self.keys else v) for k, v in data.items()}
30+
return np.log1p(data)
2931
else:
30-
return {k: (np.log(v) if k in self.keys else v) for k, v in data.items()}
32+
return np.log(data)
3133

32-
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
34+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3335
if self.p1:
34-
return {k: (np.expm1(v) if k in self.keys else v) for k, v in data.items()}
36+
return np.expm1(data)
3537
else:
36-
return {k: (np.exp(v) if k in self.keys else v) for k, v in data.items()}
38+
return np.exp(data)
3739

3840
@classmethod
3941
def from_config(cls, config: dict, custom_objects=None) -> "Log":
40-
return cls()
42+
return cls(
43+
p1=deserialize(config["p1"], custom_objects),
44+
)
4145

4246
def get_config(self) -> dict:
43-
return {}
47+
return {
48+
"p1": serialize(self.p1),
49+
}

bayesflow/adapters/transforms/sqrt.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import numpy as np
22

3-
from collections.abc import Sequence
4-
53
from .elementwise_transform import ElementwiseTransform
64

75

@@ -13,15 +11,11 @@ class Sqrt(ElementwiseTransform):
1311
>>> adapter = bf.Adapter().sqrt(["x"])
1412
"""
1513

16-
def __init__(self, keys: Sequence[str]):
17-
super().__init__()
18-
self.keys = keys
19-
20-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
21-
return {k: (np.sqrt(v) if k in self.keys else v) for k, v in data.items()}
14+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
15+
return np.sqrt(data)
2216

23-
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
24-
return {k: (np.square(v) if k in self.keys else v) for k, v in data.items()}
17+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
18+
return np.square(data)
2519

2620
@classmethod
2721
def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":

tests/test_adapters/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def adapter():
3030
.concatenate(["y1", "y2"], into="y")
3131
.expand_dims(["z1"], axis=2)
3232
.apply(forward=forward_transform, inverse=inverse_transform)
33-
# TODO: fix this in keras
34-
# .apply(include="p1", forward=np.log, inverse=np.exp)
33+
.log("p1")
3534
.constrain("p2", lower=0)
3635
.standardize(exclude=["t1", "t2", "o1"])
3736
.drop("d1")

tests/test_adapters/test_adapters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def test_constrain():
8989
assert np.isinf(result["x_both_disc2"][-1])
9090

9191

92-
def test_log_sqrt(random_data):
93-
# check if constraint-implied transforms are applied correctly
92+
def test_simple_transforms(random_data):
93+
# check if simple transforms are applied correctly
9494
from bayesflow.adapters import Adapter
9595

9696
adapter = Adapter().log(["o1", "p2"]).log("t1", p1=True).sqrt("p1")

0 commit comments

Comments
 (0)