File tree Expand file tree Collapse file tree 2 files changed +31
-7
lines changed
bayesflow/adapters/transforms Expand file tree Collapse file tree 2 files changed +31
-7
lines changed Original file line number Diff line number Diff line change 1+ from keras .saving import (
2+ deserialize_keras_object as deserialize ,
3+ register_keras_serializable as serializable ,
4+ serialize_keras_object as serialize ,
5+ )
16import numpy as np
27
38from .elementwise_transform import ElementwiseTransform
49
510
11+ @serializable (package = "bayesflow.adapters" )
612class Scale (ElementwiseTransform ):
7- def __init__ (self , scale : float | np .ndarray ):
8- self .scale = scale
13+ def __init__ (self , scale : np .typing .ArrayLike ):
14+ self .scale = np .array (scale )
15+
16+ @classmethod
17+ def from_config (cls , config : dict , custom_objects = None ) -> "ElementwiseTransform" :
18+ return cls (scale = deserialize (config ["scale" ]))
19+
20+ def get_config (self ) -> dict :
21+ return {"scale" : serialize (self .scale )}
922
1023 def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
1124 return data * self .scale
Original file line number Diff line number Diff line change 1+ from keras .saving import (
2+ deserialize_keras_object as deserialize ,
3+ register_keras_serializable as serializable ,
4+ serialize_keras_object as serialize ,
5+ )
16import numpy as np
27
38from .elementwise_transform import ElementwiseTransform
49
10+
11+ @serializable (package = "bayesflow.adapters" )
512class Shift (ElementwiseTransform ):
6- def __init__ (self , shift : float | np .ndarray ):
7- self .shift = shift
13+ def __init__ (self , shift : np .typing .ArrayLike ):
14+ self .shift = np .array (shift )
15+
16+ @classmethod
17+ def from_config (cls , config : dict , custom_objects = None ) -> "ElementwiseTransform" :
18+ return cls (shift = deserialize (config ["shift" ]))
19+
20+ def get_config (self ) -> dict :
21+ return {"shift" : serialize (self .shift )}
822
923 def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
1024 return data + self .shift
1125
1226 def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
1327 return data - self .shift
14-
15-
16-
You can’t perform that action at this time.
0 commit comments