Skip to content

Commit 4f65c9b

Browse files
committed
initialize Percentile with axes in ScaleRange
1 parent 819f87a commit 4f65c9b

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,20 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
167167
if self.mode == "per_sample":
168168
return {}
169169
elif self.mode == "per_dataset":
170-
measures = {Percentile(self.min_percentile), Percentile(self.max_percentile)}
170+
measures = {
171+
Percentile(self.min_percentile, axes=self.axes),
172+
Percentile(self.max_percentile, axes=self.axes),
173+
}
171174
return {self.reference_tensor or self.tensor_name: measures}
172175
else:
173176
raise ValueError(self.mode)
174177

175178
def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
176179
if self.mode == "per_sample":
177-
measures = {Percentile(self.min_percentile), Percentile(self.max_percentile)}
180+
measures = {
181+
Percentile(self.min_percentile, axes=self.axes),
182+
Percentile(self.max_percentile, axes=self.axes),
183+
}
178184
return {self.reference_tensor or self.tensor_name: measures}
179185
elif self.mode == "per_dataset":
180186
return {}
@@ -200,6 +206,10 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
200206

201207
return ensure_dtype((tensor - v_lower) / v_upper, dtype="float32")
202208

209+
def __post_init__(self):
210+
super().__post_init__()
211+
self.axes = None if self.axes is None else tuple(self.axes) # make sure axes is Tuple[str] or None
212+
203213

204214
@dataclass
205215
class Sigmoid(Processing):

0 commit comments

Comments
 (0)