-from dataclasses import dataclass, field, fields
-from typing import Mapping, Optional, Sequence, Type, Union
-
+"""Here pre- and postprocessing operations are implemented according to their definitions in bioimageio.spec:
+see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md
+and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md
+"""
+import numbers
+from dataclasses import InitVar, dataclass, field, fields
+from typing import List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import numpy
import numpy as np
import xarray as xr

@@ -33,7 +39,7 @@ def _get_fixed(

@dataclass
class Processing:
-    """base class for all Pre- and Postprocessing transformations"""
+    """base class for all Pre- and Postprocessing transformations."""

    tensor_name: str
    # todo: in python>=3.10 we should use dataclasses.KW_ONLY instead of MISSING (see child classes) to make inheritance work properly
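
Note on the todo above (not part of the change): before Python 3.10 a dataclass subclass may not add a default-less field after an inherited field that has a default, which is why the child classes below fall back on a MISSING sentinel. A rough sketch of the KW_ONLY alternative the comment refers to, with purely illustrative field names rather than the real Processing fields:

    from dataclasses import KW_ONLY, dataclass, field

    @dataclass
    class Base:
        tensor_name: str
        _: KW_ONLY                                              # fields below become keyword-only (Python >= 3.10)
        computed_measures: dict = field(default_factory=dict)   # hypothetical defaulted base field

    @dataclass
    class Child(Base):
        threshold: float   # no MISSING sentinel needed; positional field order stays legal

    child = Child("input0", 0.5)   # computed_measures defaults to {}
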
@@ -87,48 +93,64 @@ def __post_init__(self):


#
-# helpers
+# Pre- and Postprocessing implementations
#
-def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
-    """
-    Convert array to a given datatype
-    """
-    return tensor.astype(dtype)


-#
-# Pre- and Postprocessing implementations
-#
+@dataclass
+class AssertDtype(Processing):
+    """Helper Processing to assert dtype."""
+
+    dtype: Union[str, Sequence[str]] = MISSING
+    assert_with: Tuple[Type[numpy.dtype], ...] = field(init=False)
+
+    def __post_init__(self):
+        if isinstance(self.dtype, str):
+            dtype = [self.dtype]
+        else:
+            dtype = self.dtype
+
+        self.assert_with = tuple(type(numpy.dtype(dt)) for dt in dtype)
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        assert isinstance(tensor.dtype, self.assert_with)
+        return tensor
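
A quick illustration of the dtype-class comparison used by AssertDtype (a sketch outside the diff; it relies on NumPy >= 1.20, where each dtype gets its own numpy.dtype subclass):

    import numpy

    arr = numpy.zeros((2, 2), dtype="uint8")
    assert isinstance(arr.dtype, type(numpy.dtype("uint8")))        # same dtype class: passes
    assert not isinstance(arr.dtype, type(numpy.dtype("float32")))  # different dtype class
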


@dataclass
class Binarize(Processing):
+    """'output = tensor > threshold'."""
+
    threshold: float = MISSING  # make dataclass inheritance work for py<3.10 by using an explicit MISSING value.

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor > self.threshold, dtype="float32")
+        return tensor > self.threshold


@dataclass
class Clip(Processing):
+    """Limit tensor values to [min, max]."""
+
    min: float = MISSING
    max: float = MISSING

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
+        return tensor.clip(min=self.min, max=self.max)


@dataclass
class EnsureDtype(Processing):
+    """Helper Processing to cast dtype if needed."""
+
    dtype: str = MISSING

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor, dtype=self.dtype)
+        return tensor.astype(self.dtype)

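
A small sketch (not part of the diff) of how the removed float32 casts are now expressed explicitly: Binarize yields a boolean tensor and EnsureDtype performs the cast. The tensor name is illustrative, and construction is assumed to work like a plain dataclass with no computed measures needed for these two operations:

    import numpy as np
    import xarray as xr

    data = xr.DataArray(np.array([[0.2, 0.8]]), dims=("b", "x"))

    mask = Binarize(tensor_name="input0", threshold=0.5).apply(data)          # dtype: bool
    as_f32 = EnsureDtype(tensor_name="input0", dtype="float32").apply(mask)   # dtype: float32
    assert as_f32.dtype == np.dtype("float32")
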

@dataclass
class ScaleLinear(Processing):
-    """scale the tensor with a fixed multiplicative and additive factor"""
+    """Scale the tensor with a fixed multiplicative and additive factor."""

    gain: Union[float, Sequence[float]] = MISSING
    offset: Union[float, Sequence[float]] = MISSING
@@ -143,7 +165,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
            gain = self.gain
            offset = self.offset

-        return ensure_dtype(tensor * gain + offset, dtype="float32")
+        return tensor * gain + offset

    def __post_init__(self):
        super().__post_init__()
@@ -154,11 +176,37 @@ def __post_init__(self):

@dataclass
class ScaleMeanVariance(Processing):
-    ...
+    """Scale the tensor s.t. its mean and variance match a reference tensor."""
+
+    mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE
+    reference_tensor: TensorName = MISSING
+    axes: Optional[Sequence[str]] = None
+    eps: float = 1e-6
+
+    def get_required_measures(self) -> RequiredMeasures:
+        axes = None if self.axes is None else tuple(self.axes)
+        return {
+            self.mode: {
+                self.tensor_name: {Mean(axes=axes), Std(axes=axes)},
+                self.reference_tensor: {Mean(axes=axes), Std(axes=axes)},
+            }
+        }
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        axes = None if self.axes is None else tuple(self.axes)
+        assert self.mode in (PER_SAMPLE, PER_DATASET)
+        mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode)
+        std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode)
+        ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode)
+        ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode)
+
+        return (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean

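
As a sanity check on the formula in apply (plain NumPy, independent of the class machinery): standardizing with the tensor's own statistics and re-scaling with the reference statistics reproduces the reference mean and std up to eps:

    import numpy as np

    rng = np.random.default_rng(0)
    t = rng.normal(loc=3.0, scale=2.0, size=10_000)
    ref = rng.normal(loc=-1.0, scale=0.5, size=10_000)
    eps = 1e-6

    out = (t - t.mean()) / (t.std() + eps) * (ref.std() + eps) + ref.mean()
    assert np.isclose(out.mean(), ref.mean(), atol=1e-4)
    assert np.isclose(out.std(), ref.std(), atol=1e-4)
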

@dataclass
class ScaleRange(Processing):
+    """Scale with percentiles."""
+
    mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE
    axes: Optional[Sequence[str]] = None
    min_percentile: float = 0.0
@@ -177,7 +225,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes))
        v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, axes=axes))

-        return ensure_dtype((tensor - v_lower) / (v_upper - v_lower + self.eps), dtype="float32")
+        return (tensor - v_lower) / (v_upper - v_lower + self.eps)

    def __post_init__(self):
        super().__post_init__()
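
A small numeric illustration of the percentile normalization above (plain NumPy; the class itself obtains v_lower/v_upper through the measures machinery):

    import numpy as np

    t = np.array([0.0, 10.0, 50.0, 90.0, 100.0])
    v_lower, v_upper = np.percentile(t, [1.0, 99.0])
    eps = 1e-6

    scaled = (t - v_lower) / (v_upper - v_lower + eps)
    # values at or below the 1st percentile map near 0, at or above the 99th percentile near 1
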
@@ -186,12 +234,16 @@ def __post_init__(self):

@dataclass
class Sigmoid(Processing):
+    """1 / (1 + e^(-tensor))."""
+
    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        return 1.0 / (1.0 + np.exp(-tensor))


@dataclass
class ZeroMeanUnitVariance(Processing):
+    """normalize to zero mean, unit variance."""
+
    mode: Mode = PER_SAMPLE
    mean: Optional[Union[float, Sequence[float]]] = None
    std: Optional[Union[float, Sequence[float]]] = None
@@ -218,8 +270,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        else:
            raise ValueError(self.mode)

-        tensor = (tensor - mean) / (std + self.eps)
-        return ensure_dtype(tensor, dtype="float32")
+        return (tensor - mean) / (std + self.eps)

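
For intuition, when `axes` covers all non-batch dimensions, the per_sample branch of ZeroMeanUnitVariance boils down to the following plain xarray computation (a sketch; the class itself obtains the statistics via the measures machinery):

    import numpy as np
    import xarray as xr

    data = xr.DataArray(np.random.rand(2, 1, 8, 8), dims=("b", "c", "y", "x"))
    axes = ("c", "y", "x")   # per sample: reduce everything but the batch axis
    eps = 1e-6

    normalized = (data - data.mean(dim=axes)) / (data.std(dim=axes) + eps)
    # each sample along "b" now has mean ~0 and std ~1
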

_KnownProcessing = TypedDict(