Skip to content

Commit 6359d66

Browse files
committed
ensure_dtype('float') via last preprocessing step
1 parent dab09b5 commit 6359d66

File tree

2 files changed

+12
-24
lines changed

2 files changed

+12
-24
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.Out
5959
combine_tensors = {}
6060
for ts in tensor_specs:
6161
# There is a difference between pre- and postprocessing:
62-
# Preprocessing always returns float32, because its output is consumed by the model.
63-
# Postprocessing, however, should return the dtype that is specified in the model spec.
64-
# todo: cast dtype for inputs before preprocessing? or check dtype?
62+
# After preprocessing we ensure float32, because the output is consumed by the model.
63+
# After postprocessing we ensure the dtype that is specified in the model spec.
6564
assert ts.name not in combine_tensors
6665
if isinstance(ts, nodes.InputTensor):
6766
# todo: assert nodes.InputTensor.dtype with assert_dtype_before?
67+
# todo: in the long run we do not want to limit model inputs to float32...
6868
combine_tensors[ts.name] = TensorProcessingInfo(
69-
[Processing(p.name, kwargs=p.kwargs) for p in ts.preprocessing]
69+
[Processing(p.name, kwargs=p.kwargs) for p in ts.preprocessing], ensure_dtype_after="float32"
7070
)
7171
elif isinstance(ts, nodes.OutputTensor):
7272
combine_tensors[ts.name] = TensorProcessingInfo(

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,6 @@ def __post_init__(self):
9292
raise NotImplementedError(f"Unsupported mode {self.mode} for {self.__class__.__name__}")
9393

9494

95-
#
96-
# helpers
97-
#
98-
def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
99-
"""
100-
Convert array to a given datatype
101-
"""
102-
return tensor.astype(dtype)
103-
104-
10595
#
10696
# Pre- and Postprocessing implementations
10797
#
@@ -129,12 +119,12 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
129119

130120
@dataclass
131121
class Binarize(Processing):
132-
"""'output = tensor > threshold' (note: returns float array)."""
122+
"""'output = tensor > threshold'."""
133123

134124
threshold: float = MISSING # make dataclass inheritance work for py<3.10 by using an explicit MISSING value.
135125

136126
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
137-
return ensure_dtype(tensor > self.threshold, dtype="float32")
127+
return tensor > self.threshold
138128

139129

140130
@dataclass
@@ -145,7 +135,7 @@ class Clip(Processing):
145135
max: float = MISSING
146136

147137
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
148-
return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
138+
return tensor.clip(min=self.min, max=self.max)
149139

150140

151141
@dataclass
@@ -155,7 +145,7 @@ class EnsureDtype(Processing):
155145
dtype: str = MISSING
156146

157147
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
158-
return ensure_dtype(tensor, dtype=self.dtype)
148+
return tensor.astype(self.dtype)
159149

160150

161151
@dataclass
@@ -175,7 +165,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
175165
gain = self.gain
176166
offset = self.offset
177167

178-
return ensure_dtype(tensor * gain + offset, dtype="float32")
168+
return tensor * gain + offset
179169

180170
def __post_init__(self):
181171
super().__post_init__()
@@ -210,8 +200,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
210200
ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode)
211201
ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode)
212202

213-
tensor = (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean
214-
return ensure_dtype(tensor, dtype="float32")
203+
return (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean
215204

216205

217206
@dataclass
@@ -236,7 +225,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
236225
v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes))
237226
v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, axes=axes))
238227

239-
return ensure_dtype((tensor - v_lower) / (v_upper - v_lower + self.eps), dtype="float32")
228+
return (tensor - v_lower) / (v_upper - v_lower + self.eps)
240229

241230
def __post_init__(self):
242231
super().__post_init__()
@@ -281,8 +270,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
281270
else:
282271
raise ValueError(self.mode)
283272

284-
tensor = (tensor - mean) / (std + self.eps)
285-
return ensure_dtype(tensor, dtype="float32")
273+
return (tensor - mean) / (std + self.eps)
286274

287275

288276
_KnownProcessing = TypedDict(

0 commit comments

Comments
 (0)