1212 from typing_extensions import Literal , get_args # type: ignore
1313
1414
15+ def _get_fixed (
16+ fixed : Union [float , Sequence [float ]], tensor : xr .DataArray , axes : Optional [Sequence [str ]]
17+ ) -> Union [float , xr .DataArray ]:
18+ if axes is None :
19+ return fixed
20+
21+ fixed_shape = tuple (s for d , s in tensor .sizes .items () if d not in axes )
22+ fixed_dims = tuple (d for d in tensor .dims if d not in axes )
23+ fixed = np .array (fixed ).reshape (fixed_shape )
24+ return xr .DataArray (fixed , dims = fixed_dims )
25+
26+
1527@dataclass
1628class Processing :
1729 """base class for all Pre- and Postprocessing transformations"""
@@ -226,8 +238,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
226238@dataclass
227239class ZeroMeanUnitVariance (Processing ):
228240 mode : Literal ["fixed" , "per_sample" , "per_dataset" ] = "per_sample"
229- mean : Optional [float ] = None
230- std : Optional [float ] = None
241+ mean : Optional [Union [ float , Sequence [ float ]] ] = None
242+ std : Optional [Union [ float , Sequence [ float ]] ] = None
231243 axes : Optional [Sequence [str ]] = None
232244 eps : float = 1.0e-6
233245
@@ -247,12 +259,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
247259 axes = None if self .axes is None else tuple (self .axes )
248260 if self .mode == "fixed" :
249261 assert self .mean is not None and self .std is not None
250- mean , std = self .mean , self .std
262+ mean = _get_fixed (self .mean , tensor , axes )
263+ std = _get_fixed (self .std , tensor , axes )
251264 elif self .mode == "per_sample" :
252- if axes :
253- mean , std = Mean (axes ).compute (tensor ), Std (axes ).compute (tensor )
254- else :
255- mean , std = tensor .mean (), tensor .std ()
265+ mean = Mean (axes ).compute (tensor )
266+ std = Std (axes ).compute (tensor )
256267 elif self .mode == "per_dataset" :
257268 mean = self .get_computed_dataset_statistics (self .tensor_name , Mean (axes ))
258269 std = self .get_computed_dataset_statistics (self .tensor_name , Std (axes ))
0 commit comments