Skip to content

Commit 0ddaad5

Browse files
committed
make scale and shift serializable
1 parent fa0d6c3 commit 0ddaad5

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

bayesflow/adapters/transforms/scale.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
16
import numpy as np
27

38
from .elementwise_transform import ElementwiseTransform
49

510

11+
@serializable(package="bayesflow.adapters")
612
class Scale(ElementwiseTransform):
7-
def __init__(self, scale: float | np.ndarray):
8-
self.scale = scale
13+
def __init__(self, scale: np.typing.ArrayLike):
14+
self.scale = np.array(scale)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(scale=deserialize(config["scale"]))
19+
20+
def get_config(self) -> dict:
21+
return {"scale": serialize(self.scale)}
922

1023
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
1124
return data * self.scale

bayesflow/adapters/transforms/shift.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
16
import numpy as np
27

38
from .elementwise_transform import ElementwiseTransform
49

10+
11+
@serializable(package="bayesflow.adapters")
512
class Shift(ElementwiseTransform):
6-
def __init__(self, shift: float | np.ndarray):
7-
self.shift = shift
13+
def __init__(self, shift: np.typing.ArrayLike):
14+
self.shift = np.array(shift)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(shift=deserialize(config["shift"]))
19+
20+
def get_config(self) -> dict:
21+
return {"shift": serialize(self.shift)}
822

923
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
1024
return data + self.shift
1125

1226
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
1327
return data - self.shift
14-
15-
16-

0 commit comments

Comments
 (0)