Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
FilterTransform,
Keep,
LambdaTransform,
Log,
MapTransform,
OneHot,
Rename,
Sqrt,
Standardize,
ToArray,
Transform,
Expand Down Expand Up @@ -481,7 +483,7 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
if isinstance(keys, str):
keys = [keys]

transform = ExpandDims(keys, axis=axis)
transform = MapTransform({key: ExpandDims(axis=axis) for key in keys})
self.transforms.append(transform)
return self

Expand All @@ -500,6 +502,23 @@ def keep(self, keys: str | Sequence[str]):
self.transforms.append(transform)
return self

def log(self, keys: str | Sequence[str], *, p1: bool = False):
"""Append an :py:class:`~transforms.Log` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to transform.
p1 : boolean
Add 1 to the input before taking the logarithm?
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: Log(p1=p1) for key in keys})
self.transforms.append(transform)
return self

def one_hot(self, keys: str | Sequence[str], num_classes: int):
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.

Expand Down Expand Up @@ -530,6 +549,21 @@ def rename(self, from_key: str, to_key: str):
self.transforms.append(Rename(from_key, to_key))
return self

def sqrt(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to transform.
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: Sqrt() for key in keys})
self.transforms.append(transform)
return self

def standardize(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .filter_transform import FilterTransform
from .keep import Keep
from .lambda_transform import LambdaTransform
from .log import Log
from .map_transform import MapTransform
from .one_hot import OneHot
from .rename import Rename
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
from .transform import Transform
Expand Down
19 changes: 5 additions & 14 deletions bayesflow/adapters/transforms/expand_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
serialize_keras_object as serialize,
)

from collections.abc import Sequence
from .elementwise_transform import ElementwiseTransform


Expand All @@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform):

Parameters
----------
keys : str or Sequence of str
The names of the variables to expand.
axis : int or tuple
The axis to expand.

Expand Down Expand Up @@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform):
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
"""

def __init__(self, keys: Sequence[str], *, axis: int | tuple):
def __init__(self, *, axis: int | tuple):
super().__init__()

self.keys = keys
self.axis = axis

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
return cls(
keys=deserialize(config["keys"], custom_objects),
axis=deserialize(config["axis"], custom_objects),
)

def get_config(self) -> dict:
return {
"keys": serialize(self.keys),
"axis": serialize(self.axis),
}

# noinspection PyMethodOverriding
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
return {k: (np.expand_dims(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.expand_dims(data, axis=self.axis)

# noinspection PyMethodOverriding
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
return {k: (np.squeeze(v, axis=self.axis) if k in self.keys else v) for k, v in data.items()}
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=self.axis)
49 changes: 49 additions & 0 deletions bayesflow/adapters/transforms/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np

from keras.saving import (
deserialize_keras_object as deserialize,
serialize_keras_object as serialize,
)

from .elementwise_transform import ElementwiseTransform


class Log(ElementwiseTransform):
"""Log transforms a variable.

Parameters
----------
p1 : boolean
Add 1 to the input before taking the logarithm?

Examples
--------
>>> adapter = bf.Adapter().log(["x"])
"""

def __init__(self, *, p1: bool = False):
super().__init__()
self.p1 = p1

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
if self.p1:
return np.log1p(data)
else:
return np.log(data)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
if self.p1:
return np.expm1(data)
else:
return np.exp(data)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Log":
return cls(
p1=deserialize(config["p1"], custom_objects),
)

def get_config(self) -> dict:
return {
"p1": serialize(self.p1),
}
25 changes: 25 additions & 0 deletions bayesflow/adapters/transforms/sqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np

from .elementwise_transform import ElementwiseTransform


class Sqrt(ElementwiseTransform):
"""Square-root transform a variable.

Examples
--------
>>> adapter = bf.Adapter().sqrt(["x"])
"""

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.sqrt(data)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.square(data)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":
return cls()

def get_config(self) -> dict:
return {}
3 changes: 1 addition & 2 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def adapter():
.concatenate(["y1", "y2"], into="y")
.expand_dims(["z1"], axis=2)
.apply(forward=forward_transform, inverse=inverse_transform)
# TODO: fix this in keras
# .apply(include="p1", forward=np.log, inverse=np.exp)
.log("p1")
.constrain("p2", lower=0)
.standardize(exclude=["t1", "t2", "o1"])
.drop("d1")
Expand Down
23 changes: 23 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,26 @@ def test_constrain():
assert np.isinf(result["x_upper_disc2"][0])
assert np.isneginf(result["x_both_disc2"][0])
assert np.isinf(result["x_both_disc2"][-1])


def test_simple_transforms(random_data):
# check if simple transforms are applied correctly
from bayesflow.adapters import Adapter

adapter = Adapter().log(["p2", "t2"]).log("t1", p1=True).sqrt("p1")

result = adapter(random_data)

assert np.array_equal(result["p2"], np.log(random_data["p2"]))
assert np.array_equal(result["t2"], np.log(random_data["t2"]))
assert np.array_equal(result["t1"], np.log1p(random_data["t1"]))
assert np.array_equal(result["p1"], np.sqrt(random_data["p1"]))

# inverse results should match the original input
inverse = adapter.inverse(result)

assert np.array_equal(inverse["p2"], random_data["p2"])
assert np.array_equal(inverse["t2"], random_data["t2"])
assert np.array_equal(inverse["t1"], random_data["t1"])
# numerical inaccuries prevent np.array_equal to work here
assert np.allclose(inverse["p1"], random_data["p1"])