Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
OneHot,
Rename,
SerializableCustomTransform,
Squeeze,
Sqrt,
Standardize,
ToArray,
Expand Down Expand Up @@ -780,6 +781,24 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq

return self

def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to squeeze.
axis : int or tuple
The axis to squeeze. As the number of batch dimensions might change, we advise using negative
numbers (i.e., indexing from the end instead of the start).
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: Squeeze(axis=axis) for key in keys})
self.transforms.append(transform)
return self

def sqrt(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.

Expand Down
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .serializable_custom_transform import SerializableCustomTransform
from .shift import Shift
from .split import Split
from .squeeze import Squeeze
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
Expand Down
43 changes: 43 additions & 0 deletions bayesflow/adapters/transforms/squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable("bayesflow.adapters")
class Squeeze(ElementwiseTransform):
"""
Squeeze dimensions of an array.

Parameters
----------
axis : int or tuple
The axis to squeeze. As the number of batch dimensions might change, we advise using negative
numbers (i.e., indexing from the end instead of the start).

Examples
--------
shape (3, 1) array:

>>> a = np.array([[1], [2], [3]])

>>> sq = bf.adapters.transforms.Squeeze(axis=-1)
>>> sq.forward(a).shape
(3,)

It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform.
"""

def __init__(self, *, axis: int | tuple):
super().__init__()
self.axis = axis

def get_config(self) -> dict:
return serialize({"axis": self.axis})

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.squeeze(data, axis=self.axis)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.expand_dims(data, axis=self.axis)
1 change: 1 addition & 0 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def serializable_fn(x):
.concatenate(["x1", "x2"], into="x")
.concatenate(["y1", "y2"], into="y")
.expand_dims(["z1"], axis=2)
.squeeze("z1", axis=2)
.log("p1")
.constrain("p2", lower=0)
.apply(include="p2", forward="exp", inverse="log")
Expand Down