Skip to content

Commit 108cec9

Browse files
committed
adapter: support log and sqrt transforms
1 parent 3ad23d6 commit 108cec9

File tree

5 files changed

+128
-0
lines changed

5 files changed

+128
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
FilterTransform,
2020
Keep,
2121
LambdaTransform,
22+
Log,
2223
MapTransform,
2324
OneHot,
2425
Rename,
26+
Sqrt,
2527
Standardize,
2628
ToArray,
2729
Transform,
@@ -500,6 +502,23 @@ def keep(self, keys: str | Sequence[str]):
500502
self.transforms.append(transform)
501503
return self
502504

505+
def log(self, keys: str | Sequence[str], *, p1: bool = False):
506+
"""Append an :py:class:`~transforms.Log` transform to the adapter.
507+
508+
Parameters
509+
----------
510+
keys : str or Sequence of str
511+
The names of the variables to expand.
512+
p1 : boolean
513+
Add 1 to the input before taking the logarithm?
514+
"""
515+
if isinstance(keys, str):
516+
keys = [keys]
517+
518+
transform = Log(keys, p1=p1)
519+
self.transforms.append(transform)
520+
return self
521+
503522
def one_hot(self, keys: str | Sequence[str], num_classes: int):
504523
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
505524
@@ -530,6 +549,21 @@ def rename(self, from_key: str, to_key: str):
530549
self.transforms.append(Rename(from_key, to_key))
531550
return self
532551

552+
def sqrt(self, keys: str | Sequence[str]):
553+
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
554+
555+
Parameters
556+
----------
557+
keys : str or Sequence of str
558+
The names of the variables to expand.
559+
"""
560+
if isinstance(keys, str):
561+
keys = [keys]
562+
563+
transform = Sqrt(keys)
564+
self.transforms.append(transform)
565+
return self
566+
533567
def standardize(
534568
self,
535569
*,

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from .filter_transform import FilterTransform
1111
from .keep import Keep
1212
from .lambda_transform import LambdaTransform
13+
from .log import Log
1314
from .map_transform import MapTransform
1415
from .one_hot import OneHot
1516
from .rename import Rename
17+
from .sqrt import Sqrt
1618
from .standardize import Standardize
1719
from .to_array import ToArray
1820
from .transform import Transform
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
3+
from collections.abc import Sequence
4+
5+
from .elementwise_transform import ElementwiseTransform
6+
7+
8+
class Log(ElementwiseTransform):
9+
"""Log transforms a variable.
10+
11+
Parameters
12+
----------
13+
p1 : boolean
14+
Add 1 to the input before taking the logarithm?
15+
16+
Examples
17+
--------
18+
>>> adapter = bf.Adapter().log(["x"])
19+
"""
20+
21+
def __init__(self, keys: Sequence[str], *, p1: bool = False):
22+
super().__init__()
23+
self.keys = keys
24+
self.p1 = p1
25+
26+
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
27+
if self.p1:
28+
return {k: (np.log1p(v) if k in self.keys else v) for k, v in data.items()}
29+
else:
30+
return {k: (np.log(v) if k in self.keys else v) for k, v in data.items()}
31+
32+
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
33+
if self.p1:
34+
return {k: (np.expm1(v) if k in self.keys else v) for k, v in data.items()}
35+
else:
36+
return {k: (np.exp(v) if k in self.keys else v) for k, v in data.items()}
37+
38+
@classmethod
39+
def from_config(cls, config: dict, custom_objects=None) -> "Log":
40+
return cls()
41+
42+
def get_config(self) -> dict:
43+
return {}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
3+
from collections.abc import Sequence
4+
5+
from .elementwise_transform import ElementwiseTransform
6+
7+
8+
class Sqrt(ElementwiseTransform):
9+
"""Square-root transform a variable.
10+
11+
Examples
12+
--------
13+
>>> adapter = bf.Adapter().sqrt(["x"])
14+
"""
15+
16+
def __init__(self, keys: Sequence[str]):
17+
super().__init__()
18+
self.keys = keys
19+
20+
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
21+
return {k: (np.sqrt(v) if k in self.keys else v) for k, v in data.items()}
22+
23+
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
24+
return {k: (np.square(v) if k in self.keys else v) for k, v in data.items()}
25+
26+
@classmethod
27+
def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":
28+
return cls()
29+
30+
def get_config(self) -> dict:
31+
return {}

tests/test_adapters/test_adapters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,21 @@ def test_constrain():
8787
assert np.isinf(result["x_upper_disc2"][0])
8888
assert np.isneginf(result["x_both_disc2"][0])
8989
assert np.isinf(result["x_both_disc2"][-1])
90+
91+
def test_log_sqrt(random_data):
92+
# check if constraint-implied transforms are applied correctly
93+
from bayesflow.adapters import Adapter
94+
95+
adapter = (
96+
Adapter()
97+
.log(["o1", "p2"])
98+
.log("t1", p1=True)
99+
.sqrt("p1")
100+
)
101+
102+
result = adapter(random_data)
103+
104+
assert np.isfinite(result["o1"][0, 0])
105+
assert np.isfinite(result["p2"][0, 0])
106+
assert np.isfinite(result["t1"][0, 0])
107+
assert np.isfinite(result["p1"][0, 0])

0 commit comments

Comments
 (0)