Skip to content

Commit 33b416c

Browse files
committed
Add squeeze transform
Very basic transform, just the inverse of expand_dims
1 parent 92426d6 commit 33b416c

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OneHot,
2323
Rename,
2424
SerializableCustomTransform,
25+
Squeeze,
2526
Sqrt,
2627
Standardize,
2728
ToArray,
@@ -780,6 +781,23 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
780781

781782
return self
782783

784+
def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
785+
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.
786+
787+
Parameters
788+
----------
789+
keys : str or Sequence of str
790+
The names of the variables to squeeze.
791+
axis : int or tuple
792+
The axis to squeeze.
793+
"""
794+
if isinstance(keys, str):
795+
keys = [keys]
796+
797+
transform = MapTransform({key: Squeeze(axis=axis) for key in keys})
798+
self.transforms.append(transform)
799+
return self
800+
783801
def sqrt(self, keys: str | Sequence[str]):
784802
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
785803

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .serializable_custom_transform import SerializableCustomTransform
2020
from .shift import Shift
2121
from .split import Split
22+
from .squeeze import Squeeze
2223
from .sqrt import Sqrt
2324
from .standardize import Standardize
2425
from .to_array import ToArray

tests/test_adapters/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def serializable_fn(x):
2121
.concatenate(["x1", "x2"], into="x")
2222
.concatenate(["y1", "y2"], into="y")
2323
.expand_dims(["z1"], axis=2)
24+
.squeeze("z1", axis=2)
2425
.log("p1")
2526
.constrain("p2", lower=0)
2627
.apply(include="p2", forward="exp", inverse="log")

0 commit comments

Comments
 (0)