Skip to content

Commit b95428d

Browse files
committed
add scale and shift transforms
1 parent 22c75d1 commit b95428d

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,18 @@ def rename(self, from_key: str, to_key: str):
555555
self.transforms.append(Rename(from_key, to_key))
556556
return self
557557

558+
def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
559+
from .transforms import Scale
560+
561+
self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys}))
562+
return self
563+
564+
def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
565+
from .transforms import Shift
566+
567+
self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
568+
return self
569+
558570
def sqrt(self, keys: str | Sequence[str]):
559571
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
560572

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .numpy_transform import NumpyTransform
1515
from .one_hot import OneHot
1616
from .rename import Rename
17+
from .scale import Scale
18+
from .shift import Shift
1719
from .sqrt import Sqrt
1820
from .standardize import Standardize
1921
from .to_array import ToArray
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
3+
from .elementwise_transform import ElementwiseTransform
4+
5+
6+
class Scale(ElementwiseTransform):
7+
def __init__(self, scale: float | np.ndarray):
8+
self.scale = scale
9+
10+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
11+
return data * self.scale
12+
13+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
14+
return data / self.scale
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
3+
from .elementwise_transform import ElementwiseTransform
4+
5+
class Shift(ElementwiseTransform):
6+
def __init__(self, shift: float | np.ndarray):
7+
self.shift = shift
8+
9+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
10+
return data + self.shift
11+
12+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
13+
return data - self.shift
14+
15+
16+

0 commit comments

Comments
 (0)