22
33import numpy as np
44
5- from keras .saving import (
6- deserialize_keras_object as deserialize ,
7- register_keras_serializable as serializable ,
8- serialize_keras_object as serialize ,
9- )
5+ from bayesflow .utils .serialization import deserialize , serialize , serializable
106
117from .transforms import (
128 AsSet ,
3329from .transforms .filter_transform import Predicate
3430
3531
36- @serializable ( package = "bayesflow.adapters" )
32+ @serializable
3733class Adapter (MutableSequence [Transform ]):
3834 """
3935 Defines an adapter to apply various transforms to data.
@@ -74,18 +70,24 @@ def create_default(inference_variables: Sequence[str]) -> "Adapter":
7470
7571 @classmethod
7672 def from_config (cls , config : dict , custom_objects = None ) -> "Adapter" :
77- return cls (transforms = deserialize (config [ "transforms" ], custom_objects ))
73+ return cls (** deserialize (config , custom_objects = custom_objects ))
7874
7975 def get_config (self ) -> dict :
80- return {"transforms" : serialize (self .transforms )}
76+ config = {
77+ "transforms" : self .transforms ,
78+ }
79+
80+ return serialize (config )
8181
82- def forward (self , data : dict [str , any ], ** kwargs ) -> dict [str , np .ndarray ]:
82+ def forward (self , data : dict [str , any ], * , stage : str = "inference" , * *kwargs ) -> dict [str , np .ndarray ]:
8383 """Apply the transforms in the forward direction.
8484
8585 Parameters
8686 ----------
8787 data : dict
8888 The data to be transformed.
89+ stage : str, one of ["training", "validation", "inference"]
90+ The stage the function is called in.
8991 **kwargs : dict
9092 Additional keyword arguments passed to each transform.
9193
@@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
9799 data = data .copy ()
98100
99101 for transform in self .transforms :
100- data = transform (data , ** kwargs )
102+ data = transform (data , stage = stage , ** kwargs )
101103
102104 return data
103105
104- def inverse (self , data : dict [str , np .ndarray ], ** kwargs ) -> dict [str , any ]:
106+ def inverse (self , data : dict [str , np .ndarray ], * , stage : str = "inference" , * *kwargs ) -> dict [str , any ]:
105107 """Apply the transforms in the inverse direction.
106108
107109 Parameters
108110 ----------
109111 data : dict
110112 The data to be transformed.
113+ stage : str, one of ["training", "validation", "inference"]
114+ The stage the function is called in.
111115 **kwargs : dict
112116 Additional keyword arguments passed to each transform.
113117
@@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
119123 data = data .copy ()
120124
121125 for transform in reversed (self .transforms ):
122- data = transform (data , inverse = True , ** kwargs )
126+ data = transform (data , stage = stage , inverse = True , ** kwargs )
123127
124128 return data
125129
126- def __call__ (self , data : Mapping [str , any ], * , inverse : bool = False , ** kwargs ) -> dict [str , np .ndarray ]:
130+ def __call__ (
131+ self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
132+ ) -> dict [str , np .ndarray ]:
127133 """Apply the transforms in the given direction.
128134
129135 Parameters
@@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
132138 The data to be transformed.
133139 inverse : bool, optional
134140 If False, apply the forward transform, else apply the inverse transform (default False).
141+ stage : str, one of ["training", "validation", "inference"]
142+ The stage the function is called in.
135143 **kwargs
136144 Additional keyword arguments passed to each transform.
137145
@@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
141149 The transformed data.
142150 """
143151 if inverse :
144- return self .inverse (data , ** kwargs )
152+ return self .inverse (data , stage = stage , ** kwargs )
145153
146- return self .forward (data , ** kwargs )
154+ return self .forward (data , stage = stage , ** kwargs )
147155
148156 def __repr__ (self ):
149157 result = ""
@@ -667,6 +675,18 @@ def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
667675 self .transforms .append (MapTransform ({key : Shift (shift = by ) for key in keys }))
668676 return self
669677
678+ def split (self , key : str , * , into : Sequence [str ], indices_or_sections : int | Sequence [int ] = None , axis : int = - 1 ):
679+ from .transforms import Split
680+
681+ if isinstance (into , str ):
682+ transform = Rename (key , into )
683+ else :
684+ transform = Split (key , into , indices_or_sections , axis )
685+
686+ self .transforms .append (transform )
687+
688+ return self
689+
670690 def sqrt (self , keys : str | Sequence [str ]):
671691 """Append an :py:class:`~transforms.Sqrt` transform to the adapter.
672692
@@ -743,3 +763,10 @@ def to_array(
743763 )
744764 self .transforms .append (transform )
745765 return self
766+
767+ def to_dict (self ):
768+ from .transforms import ToDict
769+
770+ transform = ToDict ()
771+ self .transforms .append (transform )
772+ return self
0 commit comments