@@ -167,14 +167,20 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
167167 if self .mode == "per_sample" :
168168 return {}
169169 elif self .mode == "per_dataset" :
170- measures = {Percentile (self .min_percentile ), Percentile (self .max_percentile )}
170+ measures = {
171+ Percentile (self .min_percentile , axes = self .axes ),
172+ Percentile (self .max_percentile , axes = self .axes ),
173+ }
171174 return {self .reference_tensor or self .tensor_name : measures }
172175 else :
173176 raise ValueError (self .mode )
174177
175178 def get_required_sample_statistics (self ) -> Dict [str , Set [Measure ]]:
176179 if self .mode == "per_sample" :
177- measures = {Percentile (self .min_percentile ), Percentile (self .max_percentile )}
180+ measures = {
181+ Percentile (self .min_percentile , axes = self .axes ),
182+ Percentile (self .max_percentile , axes = self .axes ),
183+ }
178184 return {self .reference_tensor or self .tensor_name : measures }
179185 elif self .mode == "per_dataset" :
180186 return {}
@@ -200,6 +206,10 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
200206
201207 return ensure_dtype ((tensor - v_lower ) / v_upper , dtype = "float32" )
202208
209+ def __post_init__ (self ):
210+ super ().__post_init__ ()
211+ self .axes = None if self .axes is None else tuple (self .axes ) # make sure axes is Tuple[str] or None
212+
203213
204214@dataclass
205215class Sigmoid (Processing ):
0 commit comments