33import xarray as xr
44from bioimageio .core .resource_io .nodes import Postprocessing
55
6- from ._preprocessing import binarize , chain
6+ from . import _preprocessing as ops
77from ._types import Transform
88
99
10- def sigmoid (tensor : xr .DataArray , ** kwargs ):
11- if kwargs :
12- raise NotImplementedError (f"Passed kwargs for sigmoid { kwargs } " )
13- return 1 / (1 + xr .ufuncs .exp (- tensor ))
10+ def scale_range (
11+ tensor : xr .DataArray ,
12+ * ,
13+ reference_tensor = None ,
14+ mode = "per_sample" ,
15+ axes = None ,
16+ min_percentile = 0.0 ,
17+ max_percentile = 100.0 ,
18+ ) -> xr .DataArray :
1419
20+ # TODO if reference tensor is passed, we need to use it to compute quantiles instead of 'tensor'
21+ if reference_tensor is None :
22+ tensor_ = tensor
23+ else :
24+ raise NotImplementedError
1525
16- KNOWN_POSTPROCESSING = {"binarize" : binarize , "sigmoid" : sigmoid }
26+ # valid modes according to spec: "per_sample", "per_dataset"
27+ # TODO implement per_dataset
28+ if mode != "per_sample" :
29+ raise NotImplementedError (f"Unsupported mode for scale_range: { mode } " )
30+
31+ if axes :
32+ axes = tuple (axes )
33+ v_lower = tensor_ .quantile (min_percentile / 100.0 , dim = axes )
34+ v_upper = tensor_ .quantile (max_percentile / 100.0 , dim = axes )
35+ else :
36+ v_lower = tensor_ .quantile (min_percentile / 100.0 )
37+ v_upper = tensor_ .quantile (max_percentile / 100.0 )
38+
39+ return ops .ensure_dtype ((tensor - v_lower ) / v_upper , dtype = "float32" )
40+
41+
42+ # TODO scale the tensor s.t. it matches the mean and variance of the reference tensor
43+ def scale_mean_variance (tensor : xr .DataArray , * , reference_tensor , mode = "per_sample" ):
44+ raise NotImplementedError
45+
46+
47+ # NOTE there is a subtle difference between pre-and-postprocessing:
48+ # pre-processing always returns float32, because the post-processing output is consumed
49+ # by the model. Post-processing, however, should return the dtype that is specified in the model spec
50+ # TODO I think the easiest way to implement this is to add dtype is an option to 'make_postprocessing'
51+ # and then apply 'ensure_dtype' to the result of the postprocessing chain
52+ KNOWN_POSTPROCESSING = {
53+ "binarize" : ops .binarize ,
54+ "clip" : ops .clip ,
55+ "scale_linear" : ops .scale_linear ,
56+ "scale_range" : ops .scale_range ,
57+ "sigmoid" : ops .sigmoid ,
58+ "zero_mean_unit_variance" : ops .zero_mean_unit_variance
59+ }
1760
1861
1962def make_postprocessing (spec : List [Postprocessing ]) -> Transform :
@@ -32,4 +75,4 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
3275
3376 functions .append ((fn , kwargs ))
3477
35- return chain (* functions )
78+ return ops . chain (* functions )
0 commit comments