Skip to content

Commit ea93ae6

Browse files
committed
Add adapter.reshape()
1 parent c1407df commit ea93ae6

File tree

4 files changed

+3146
-0
lines changed

4 files changed

+3146
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
NumpyTransform,
2222
OneHot,
2323
Rename,
24+
Reshape,
2425
SerializableCustomTransform,
2526
Squeeze,
2627
Sqrt,
@@ -746,6 +747,25 @@ def rename(self, from_key: str, to_key: str):
746747
self.transforms.append(Rename(from_key, to_key))
747748
return self
748749

750+
def reshape(self, keys: str | Sequence[str], *, to: int | Sequence[int]):
751+
"""Append a :py:class:`~transforms.Reshape` transform to the adapter.
752+
753+
Parameters
754+
----------
755+
keys : str or Sequence of str
756+
Variables that should be reshaped
757+
to : int or tuple of int
758+
Target shape of the variables
759+
"""
760+
from .transforms import Reshape
761+
762+
if isinstance(keys, str):
763+
keys = [keys]
764+
765+
transform = MapTransform({key: Reshape(shape=to) for key in keys})
766+
self.transforms.append(transform)
767+
return self
768+
749769
def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
750770
from .transforms import Scale
751771

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .numpy_transform import NumpyTransform
1616
from .one_hot import OneHot
1717
from .rename import Rename
18+
from .reshape import Reshape
1819
from .scale import Scale
1920
from .serializable_custom_transform import SerializableCustomTransform
2021
from .shift import Shift
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
3+
from collections.abc import Sequence
4+
from bayesflow.utils.serialization import serializable, serialize
5+
6+
from .elementwise_transform import ElementwiseTransform
7+
8+
9+
@serializable("bayesflow.adapters")
10+
class Reshape(ElementwiseTransform):
11+
12+
def __init__(self, shape: int | Sequence[int]):
13+
super().__init__()
14+
15+
if isinstance(shape, Sequence):
16+
shape = tuple(shape)
17+
self.shape = shape
18+
19+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
20+
return np.reshape(data, self.shape)
21+
22+
23+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
24+
return np.reshape(data, self.shape)
25+
26+
27+
def get_config(self) -> dict:
28+
return {"shape": self.shape}

0 commit comments

Comments
 (0)