11from collections .abc import Sequence
22import numpy as np
33
4+ from keras .saving import (
5+ deserialize_keras_object as deserialize ,
6+ register_keras_serializable as serializable ,
7+ serialize_keras_object as serialize ,
8+ )
9+
410from .transform import Transform
511
612
13+ @serializable (package = "bayesflow.adapters" )
714class Split (Transform ):
815 """This is the effective inverse of the :py:class:`~Concatenate` Transform.
916
@@ -29,6 +36,23 @@ def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Seq
2936
3037 self .indices_or_sections = indices_or_sections
3138
39+ @classmethod
40+ def from_config (cls , config : dict , custom_objects = None ) -> "Split" :
41+ return cls (
42+ key = deserialize (config ["key" ], custom_objects ),
43+ into = deserialize (config ["into" ], custom_objects ),
44+ indices_or_sections = deserialize (config ["indices_or_sections" ], custom_objects ),
45+ axis = deserialize (config ["axis" ], custom_objects ),
46+ )
47+
48+ def get_config (self ) -> dict :
49+ return {
50+ "key" : serialize (self .key ),
51+ "into" : serialize (self .into ),
52+ "indices_or_sections" : serialize (self .indices_or_sections ),
53+ "axis" : serialize (self .axis ),
54+ }
55+
3256 def forward (self , data : dict [str , np .ndarray ], strict : bool = True , ** kwargs ) -> dict [str , np .ndarray ]:
3357 # avoid side effects
3458 data = data .copy ()
@@ -39,7 +63,7 @@ def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) ->
3963 # we cannot produce a result, but also don't have to
4064 return data
4165
42- splits = np .split (data .pop (self .key ), self .indices_or_sections )
66+ splits = np .split (data .pop (self .key ), self .indices_or_sections , axis = self . axis )
4367
4468 if len (splits ) != len (self .into ):
4569 raise ValueError (f"Requested { len (self .into )} splits, but produced { len (splits )} ." )
0 commit comments