55 serialize_keras_object as serialize ,
66)
77
8- from collections .abc import Sequence
98from .elementwise_transform import ElementwiseTransform
109
1110
@@ -15,8 +14,6 @@ class ExpandDims(ElementwiseTransform):
1514
1615 Parameters
1716 ----------
18- keys : str or Sequence of str
19- The names of the variables to expand.
2017 axis : int or tuple
2118 The axis to expand.
2219
@@ -49,29 +46,23 @@ class ExpandDims(ElementwiseTransform):
4946 It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
5047 """
5148
52- def __init__ (self , keys : Sequence [ str ], * , axis : int | tuple ):
49+ def __init__ (self , * , axis : int | tuple ):
5350 super ().__init__ ()
54-
55- self .keys = keys
5651 self .axis = axis
5752
5853 @classmethod
5954 def from_config (cls , config : dict , custom_objects = None ) -> "ExpandDims" :
6055 return cls (
61- keys = deserialize (config ["keys" ], custom_objects ),
6256 axis = deserialize (config ["axis" ], custom_objects ),
6357 )
6458
6559 def get_config (self ) -> dict :
6660 return {
67- "keys" : serialize (self .keys ),
6861 "axis" : serialize (self .axis ),
6962 }
7063
71- # noinspection PyMethodOverriding
72- def forward (self , data : dict [str , any ], ** kwargs ) -> dict [str , np .ndarray ]:
73- return {k : (np .expand_dims (v , axis = self .axis ) if k in self .keys else v ) for k , v in data .items ()}
64+ def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
65+ return np .expand_dims (data , axis = self .axis )
7466
75- # noinspection PyMethodOverriding
76- def inverse (self , data : dict [str , any ], ** kwargs ) -> dict [str , np .ndarray ]:
77- return {k : (np .squeeze (v , axis = self .axis ) if k in self .keys else v ) for k , v in data .items ()}
67+ def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
68+ return np .squeeze (data , axis = self .axis )
0 commit comments