33import xarray as xr
44from bioimageio .core .resource_io .nodes import Postprocessing
55
6- from ._preprocessing import binarize , chain
6+ from . import _preprocessing as ops
77from ._types import Transform
88
99
10- def sigmoid (tensor : xr .DataArray , ** kwargs ):
11- if kwargs :
12- raise NotImplementedError (f"Passed kwargs for sigmoid { kwargs } " )
13- return 1 / (1 + xr .ufuncs .exp (- tensor ))
10+ # TODO how do we implement reference_tensor?
1411
1512
16- KNOWN_POSTPROCESSING = {"binarize" : binarize , "sigmoid" : sigmoid }
13+ def scale_range (
14+ tensor : xr .DataArray ,
15+ * ,
16+ reference_tensor = None ,
17+ mode = "per_sample" ,
18+ axes = None ,
19+ min_percentile = 0.0 ,
20+ max_percentile = 100.0 ,
21+ ) -> xr .DataArray :
1722
23+ # TODO if reference tensor is passed, we need to use it to compute quantiles instead of 'tensor'
24+ if reference_tensor is None :
25+ tensor_ = tensor
26+ else :
27+ raise NotImplementedError
1828
19- def make_postprocessing (spec : List [Postprocessing ]) -> Transform :
29+ # valid modes according to spec: "per_sample", "per_dataset"
30+ # TODO implement per_dataset
31+ if mode != "per_sample" :
32+ raise NotImplementedError (f"Unsupported mode for scale_range: { mode } " )
33+
34+ if axes :
35+ axes = tuple (axes )
36+ v_lower = tensor_ .quantile (min_percentile / 100.0 , dim = axes )
37+ v_upper = tensor_ .quantile (max_percentile / 100.0 , dim = axes )
38+ else :
39+ v_lower = tensor_ .quantile (min_percentile / 100.0 )
40+ v_upper = tensor_ .quantile (max_percentile / 100.0 )
41+
42+ return ops .ensure_dtype ((tensor - v_lower ) / v_upper , dtype = "float32" )
43+
44+
45+ # TODO scale the tensor s.t. it matches the mean and variance of the reference tensor
46+ def scale_mean_variance (tensor : xr .DataArray , * , reference_tensor , mode = "per_sample" ):
47+ raise NotImplementedError
48+
49+
50+ KNOWN_POSTPROCESSING = {
51+ "binarize" : ops .binarize ,
52+ "clip" : ops .clip ,
53+ "scale_linear" : ops .scale_linear ,
54+ "scale_range" : ops .scale_range ,
55+ "sigmoid" : ops .sigmoid ,
56+ "zero_mean_unit_variance" : ops .zero_mean_unit_variance ,
57+ }
58+
59+
60+ def make_postprocessing (spec : List [Postprocessing ], dtype : str ) -> Transform :
2061 """
2162 :param preprocessing: bioimage-io spec node
2263 """
@@ -32,4 +73,9 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
3273
3374 functions .append ((fn , kwargs ))
3475
35- return chain (* functions )
76+ # There is a difference between pre-and-postprocessing:
77+ # Tre-processing always returns float32, because its output is consumed y the model.
78+ # Post-processing, however, should return the dtype that is specified in the model spec.
79+ functions .append ((ops .ensure_dtype , {"dtype" : dtype }))
80+
81+ return ops .chain (* functions )
0 commit comments