11import numpy as np
2+ from keras .saving import register_keras_serializable as serializable
23
34from bayesflow .utils import filter_kwargs
45from .elementwise_transform import ElementwiseTransform
56
67
8+ @serializable (package = "bayesflow.adapters" )
79class NumpyTransform (ElementwiseTransform ):
810 """
911 A class to apply element-wise transformations using plain NumPy functions.
@@ -17,45 +19,46 @@ class NumpyTransform(ElementwiseTransform):
1719 """
1820
1921 INVERSE_METHODS = {
20- " arctan" : " tan" ,
21- " exp" : " log" ,
22- " expm1" : " log1p" ,
23- " square" : " sqrt" ,
24- " reciprocal" : " reciprocal" ,
22+ np . arctan : np . tan ,
23+ np . exp : np . log ,
24+ np . expm1 : np . log1p ,
25+ np . square : np . sqrt ,
26+ np . reciprocal : np . reciprocal ,
2527 }
2628 # ensure the map is symmetric
2729 INVERSE_METHODS |= {v : k for k , v in INVERSE_METHODS .items ()}
2830
29- def __init__ (self , forward : np . ufunc | str , inverse : np . ufunc | str = None ):
31+ def __init__ (self , forward : str , inverse : str = None ):
3032 """
3133 Initializes the NumpyTransform with specified forward and inverse functions.
3234
3335 Parameters:
3436 ----------
35- forward : str
37+ forward: str
3638 The name of the NumPy function to use for the forward transformation.
37- inverse : str
39+ inverse: str, optional
3840 The name of the NumPy function to use for the inverse transformation.
3941 By default, the inverse is inferred from the forward argument for supported methods.
4042 """
4143 super ().__init__ ()
4244
43- if isinstance (forward , np .ufunc ):
44- forward = forward .__name__
45+ if isinstance (forward , str ):
46+ forward = getattr (np , forward )
47+
48+ if not isinstance (forward , np .ufunc ):
49+ raise ValueError ("Forward transformation must be a NumPy Universal Function (ufunc)." )
4550
4651 if inverse is None :
4752 if forward not in self .INVERSE_METHODS :
4853 raise ValueError (f"Cannot infer inverse for method { forward !r} " )
4954
5055 inverse = self .INVERSE_METHODS [forward ]
51- elif isinstance (inverse , np .ufunc ):
52- inverse = inverse .__name__
5356
54- if forward not in dir ( np ):
55- raise ValueError ( f"Method { forward !r } not found in numpy version { np . __version__ } " )
57+ if isinstance ( inverse , str ):
58+ inverse = getattr ( np , inverse )
5659
57- if inverse not in dir ( np ):
58- raise ValueError (f"Method { inverse !r } not found in numpy version { np . __version__ } " )
60+ if not isinstance ( inverse , np . ufunc ):
61+ raise ValueError ("Inverse transformation must be a NumPy Universal Function (ufunc). " )
5962
6063 self ._forward = forward
6164 self ._inverse = inverse
@@ -68,14 +71,12 @@ def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform
6871 )
6972
7073 def get_config (self ) -> dict :
71- return {"forward" : self ._forward , "inverse" : self ._inverse }
74+ return {"forward" : self ._forward . __name__ , "inverse" : self ._inverse . __name__ }
7275
7376 def forward (self , data : dict [str , any ], ** kwargs ) -> dict [str , any ]:
74- forward = getattr (np , self ._forward )
75- kwargs = filter_kwargs (kwargs , forward )
76- return forward (data , ** kwargs )
77+ kwargs = filter_kwargs (kwargs , self ._forward )
78+ return self ._forward (data , ** kwargs )
7779
7880 def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
79- inverse = getattr (np , self ._inverse )
80- kwargs = filter_kwargs (kwargs , inverse )
81- return inverse (data , ** kwargs )
81+ kwargs = filter_kwargs (kwargs , self ._inverse )
82+ return self ._inverse (data , ** kwargs )
0 commit comments