Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
FilterTransform,
Keep,
LambdaTransform,
Log,
MapTransform,
OneHot,
Rename,
Sqrt,
Standardize,
ToArray,
Transform,
Expand Down Expand Up @@ -500,6 +502,23 @@ def keep(self, keys: str | Sequence[str]):
self.transforms.append(transform)
return self

def log(self, keys: str | Sequence[str], *, p1: bool = False):
"""Append an :py:class:`~transforms.Log` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to expand.
p1 : boolean
Add 1 to the input before taking the logarithm?
"""
if isinstance(keys, str):
keys = [keys]

transform = Log(keys, p1=p1)
self.transforms.append(transform)
return self

def one_hot(self, keys: str | Sequence[str], num_classes: int):
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.

Expand Down Expand Up @@ -530,6 +549,21 @@ def rename(self, from_key: str, to_key: str):
self.transforms.append(Rename(from_key, to_key))
return self

def sqrt(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to expand.
"""
if isinstance(keys, str):
keys = [keys]

transform = Sqrt(keys)
self.transforms.append(transform)
return self

def standardize(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .filter_transform import FilterTransform
from .keep import Keep
from .lambda_transform import LambdaTransform
from .log import Log
from .map_transform import MapTransform
from .one_hot import OneHot
from .rename import Rename
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
from .transform import Transform
Expand Down
43 changes: 43 additions & 0 deletions bayesflow/adapters/transforms/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from collections.abc import Sequence

from .elementwise_transform import ElementwiseTransform


class Log(ElementwiseTransform):
"""Log transforms a variable.

Parameters
----------
p1 : boolean
Add 1 to the input before taking the logarithm?

Examples
--------
>>> adapter = bf.Adapter().log(["x"])
"""

def __init__(self, keys: Sequence[str], *, p1: bool = False):
super().__init__()
self.keys = keys
self.p1 = p1

def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
if self.p1:
return {k: (np.log1p(v) if k in self.keys else v) for k, v in data.items()}
else:
return {k: (np.log(v) if k in self.keys else v) for k, v in data.items()}

def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
if self.p1:
return {k: (np.expm1(v) if k in self.keys else v) for k, v in data.items()}
else:
return {k: (np.exp(v) if k in self.keys else v) for k, v in data.items()}

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Log":
return cls()

def get_config(self) -> dict:
return {}
31 changes: 31 additions & 0 deletions bayesflow/adapters/transforms/sqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

from collections.abc import Sequence

from .elementwise_transform import ElementwiseTransform


class Sqrt(ElementwiseTransform):
"""Square-root transform a variable.

Examples
--------
>>> adapter = bf.Adapter().sqrt(["x"])
"""

def __init__(self, keys: Sequence[str]):
super().__init__()
self.keys = keys

def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
return {k: (np.sqrt(v) if k in self.keys else v) for k, v in data.items()}

def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
return {k: (np.square(v) if k in self.keys else v) for k, v in data.items()}

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":
return cls()

def get_config(self) -> dict:
return {}
14 changes: 14 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,17 @@ def test_constrain():
assert np.isinf(result["x_upper_disc2"][0])
assert np.isneginf(result["x_both_disc2"][0])
assert np.isinf(result["x_both_disc2"][-1])


def test_log_sqrt(random_data):
# check if constraint-implied transforms are applied correctly
from bayesflow.adapters import Adapter

adapter = Adapter().log(["o1", "p2"]).log("t1", p1=True).sqrt("p1")

result = adapter(random_data)

assert np.isfinite(result["o1"][0, 0])
assert np.isfinite(result["p2"][0, 0])
assert np.isfinite(result["t1"][0, 0])
assert np.isfinite(result["p1"][0, 0])