1- from dataclasses import dataclass
2- from typing import (
3- Any ,
4- Dict ,
5- List ,
6- Literal ,
7- Optional ,
8- Sequence ,
9- Set ,
10- Type ,
11- Union ,
12- get_args ,
13- )
1+ from dataclasses import dataclass , field
2+ from typing import Any , Dict , Literal , Optional , Sequence , Set , get_args
143
154import numpy as np
165import xarray as xr
176
18- from bioimageio .core .resource_io import nodes
197from bioimageio .core .statistical_measures import Mean , Measure , Percentile , Std
20- from bioimageio .spec .model .v0_3 .raw_nodes import PreprocessingName
218
229
2310def ensure_dtype (tensor : xr .DataArray , * , dtype ) -> xr .DataArray :
@@ -29,8 +16,11 @@ def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
2916
3017@dataclass
3118class Processing :
32- apply_to : str
33- computed_statistics : Dict [str , Dict [Measure , Any ]]
    """base class for all Pre- and Postprocessing transformations"""

    tensor_name: str  # name of the tensor this processing is applied to
    # Statistics are injected after construction via set_computed_dataset_statistics /
    # set_computed_sample_statistics (hence init=False); reading either attribute
    # before the corresponding setter was called raises AttributeError.
    computed_dataset_statistics: Dict[str, Dict[Measure, Any]] = field(init=False)
    computed_sample_statistics: Dict[str, Dict[Measure, Any]] = field(init=False)
3424
3525 def get_required_dataset_statistics (self ) -> Dict [str , Set [Measure ]]:
3626 """
@@ -39,30 +29,48 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
3929 """
4030 return {}
4131
42- def set_computed_statistics (self , computed : Dict [str , Dict [Measure , Any ]]):
32+ def get_required_sample_statistics (self ) -> Dict [str , Set [Measure ]]:
33+ """
34+ Specifies which sample measures are required from what tensor.
35+ Returns: sample measures required to apply this processing indexed by <tensor_name>.
36+ """
37+
38+ def set_computed_dataset_statistics (self , computed : Dict [str , Dict [Measure , Any ]]):
4339 """helper to set computed statistics and check if they match the requirements"""
4440 for tensor_name , req_measures in self .get_required_dataset_statistics ():
4541 comp_measures = computed .get (tensor_name , {})
4642 for req_measure in req_measures :
4743 if req_measure not in comp_measures :
48- raise ValueError ("Missing required measure {req_measure} for {tensor_name}" )
49- self .computed_statistics = computed
44+ raise ValueError (f"Missing required measure { req_measure } for { tensor_name } " )
45+ self .computed_dataset_statistics = computed
46+
47+ def set_computed_sample_statistics (self , computed : Dict [str , Dict [Measure , Any ]]):
48+ """helper to set computed statistics and check if they match the requirements"""
49+ for tensor_name , req_measures in self .get_required_sample_statistics ():
50+ comp_measures = computed .get (tensor_name , {})
51+ for req_measure in req_measures :
52+ if req_measure not in comp_measures :
53+ raise ValueError (f"Missing required measure { req_measure } for { tensor_name } " )
54+ self .computed_sample_statistics = computed
5055
51- def get_computed_statistics (self , tensor_name : str , measure : Measure ):
52- """helper to unpack self.computed_statistics """
53- ret = self .computed_statistics .get (tensor_name , {}).get (measure )
56+ def get_computed_dataset_statistics (self , tensor_name : str , measure : Measure ):
57+ """helper to unpack self.computed_dataset_statistics """
58+ ret = self .computed_dataset_statistics .get (tensor_name , {}).get (measure )
5459 if ret is None :
5560 raise RuntimeError (f"Missing computed { measure } for { tensor_name } dataset." )
5661
5762 return ret
5863
59- def apply (self , ** tensors : xr .DataArray ) -> Dict [str , xr .DataArray ]:
60- """apply processing to named tensors; call 'apply_simple' as default"""
61- tensors [self .apply_to ] = self .apply_simple (tensors [self .apply_to ])
62- return tensors
64+ def get_computed_sample_statistics (self , tensor_name : str , measure : Measure ):
65+ """helper to unpack self.computed_sample_statistics"""
66+ ret = self .computed_sample_statistics .get (tensor_name , {}).get (measure )
67+ if ret is None :
68+ raise RuntimeError (f"Missing computed { measure } for { tensor_name } sample." )
6369
64- def apply_simple (self , tensor : xr .DataArray ) -> xr .DataArray :
65- """apply processing to single tensor"""
70+ return ret
71+
    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        """apply processing to a single tensor; subclasses must override"""
        raise NotImplementedError
6775
6876 def __post_init__ (self ):
@@ -72,6 +80,28 @@ def __post_init__(self):
7280 raise NotImplementedError (f"Unsupported mode { self .mode } for { self .__class__ .__name__ } : { self .mode } " )
7381
7482
83+ #
84+ # Pre- and Postprocessing implementations
85+ #
86+
87+
@dataclass
class Binarize(Processing):
    """binarize the tensor: values above `threshold` map to 1.0, all others to 0.0"""

    threshold: float

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        above = tensor > self.threshold
        return ensure_dtype(above, dtype="float32")
94+
95+
@dataclass
class Clip(Processing):
    """clamp tensor values to the closed interval [min, max]"""

    min: float
    max: float

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        clamped = tensor.clip(min=self.min, max=self.max)
        return ensure_dtype(clamped, dtype="float32")
103+
104+
75105@dataclass
76106class ScaleLinear (Processing ):
77107 """scale the tensor with a fixed multiplicative and additive factor"""
@@ -93,48 +123,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
93123
94124
95125@dataclass
96- class ZeroMeanUnitVariance (Processing ):
97- mode : Literal ["fixed" , "per_sample" , "per_dataset" ] = "per_sample"
98- mean : Optional [float ] = None
99- std : Optional [float ] = None
100- axes : Optional [Sequence [str ]] = None
101- eps : float = 1.0e-6
102-
103- def get_required_dataset_statistics (self ) -> Dict [str , Set [Measure ]]:
104- if self .mode == "per_dataset" :
105- return {self .apply_to : {Mean (), Std ()}}
106- else :
107- return {}
108-
109- def apply (self , ** tensors : xr .DataArray ) -> Dict [str , xr .DataArray ]:
110- tensor = tensors [self .apply_to ]
111- if self .mode == "fixed" :
112- assert self .mean is not None and self .std is not None
113- mean , std = self .mean , self .std
114- elif self .mode == "per_sample" :
115- if self .axes :
116- axes = tuple (self .axes )
117- mean , std = tensor .mean (axes ), tensor .std (axes )
118- else :
119- mean , std = tensor .mean (), tensor .std ()
120- elif self .mode == "per_dataset" :
121- mean = self .get_computed_statistics (self .apply_to , "mean" )
122- std = self .get_computed_statistics (self .apply_to , "std" )
123- else :
124- raise ValueError (self .mode )
125-
126- tensor = (tensor - mean ) / (std + self .eps )
127- tensors [self .apply_to ] = ensure_dtype (tensor , dtype = "float32" )
128-
129- return tensors
130-
131-
132- @dataclass
133- class Binarize (Processing ):
134- threshold : float
135-
136- def apply_simple (self , tensor : xr .DataArray ) -> xr .DataArray :
137- return ensure_dtype (tensor > self .threshold , dtype = "float32" )
class ScaleMeanVariance(Processing):
    """TODO(review): scale_mean_variance is not implemented yet; calling
    apply() on an instance raises NotImplementedError via the Processing base."""

    ...
138128
139129
140130@dataclass
@@ -150,39 +140,37 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
150140 return {}
151141 elif self .mode == "per_dataset" :
152142 measures = {Percentile (self .min_percentile ), Percentile (self .max_percentile )}
153- return {self .reference_tensor or self .apply_to : measures }
143+ return {self .reference_tensor or self .tensor_name : measures }
154144 else :
155145 raise ValueError (self .mode )
156146
157- def apply (self , ** tensors : xr .DataArray ) -> Dict [str , xr .DataArray ]:
158- ref_name = self .reference_tensor or self .apply_to
147+ def get_required_sample_statistics (self ) -> Dict [str , Set [Measure ]]:
159148 if self .mode == "per_sample" :
160- ref_tensor = tensors [ref_name ]
161- if self .axes :
162- axes = tuple (self .axes )
163- else :
164- axes = None
165-
166- v_lower = ref_tensor .quantile (self .min_percentile / 100.0 , dim = axes )
167- v_upper = ref_tensor .quantile (self .max_percentile / 100.0 , dim = axes )
149+ measures = {Percentile (self .min_percentile ), Percentile (self .max_percentile )}
150+ return {self .reference_tensor or self .tensor_name : measures }
168151 elif self .mode == "per_dataset" :
169- v_lower = self .get_computed_statistics (ref_name , Percentile (self .min_percentile ))
170- v_upper = self .get_computed_statistics (ref_name , Percentile (self .max_percentile ))
152+ return {}
171153 else :
172154 raise ValueError (self .mode )
173155
174- tensors [self .apply_to ] = ensure_dtype ((tensors [self .apply_to ] - v_lower ) / v_upper , dtype = "float32" )
175- return tensors
156+ def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
157+ ref_name = self .reference_tensor or self .tensor_name
158+ if self .axes :
159+ axes = tuple (self .axes )
160+ else :
161+ axes = None
176162
163+ if self .mode == "per_sample" :
164+ get_stat = self .get_computed_sample_statistics
165+ elif self .mode == "per_dataset" :
166+ get_stat = self .get_computed_dataset_statistics
167+ else :
168+ raise ValueError (self .mode )
177169
178- # todo: continue here....
179- @dataclass
180- class Clip (Processing ):
181- min : float
182- max : float
170+ v_lower = get_stat (ref_name , Percentile (self .min_percentile , axes = axes ))
171+ v_upper = get_stat (ref_name , Percentile (self .max_percentile , axes = axes ))
183172
184- def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
185- return ensure_dtype (tensor .clip (min = self .min , max = self .max ), dtype = "float32" )
173+ return ensure_dtype ((tensor - v_lower ) / v_upper , dtype = "float32" )
186174
187175
188176@dataclass
@@ -191,26 +179,41 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
191179 return 1.0 / (1.0 + xr .ufuncs .exp (- tensor ))
192180
193181
194- KNOWN_PREPROCESSING : Dict [PreprocessingName , Type [Processing ]] = {
195- "scale_linear" : ScaleLinear ,
196- "zero_mean_unit_variance" : ZeroMeanUnitVariance ,
197- "binarize" : Binarize ,
198- "clip" : Clip ,
199- "scale_range" : ScaleRange ,
200- "sigmoid" : Sigmoid ,
201- }
202-
203-
204- class CombinedProcessing :
205- def __init__ (
206- self ,
207- processing_spec : Union [List [nodes .Preprocessing ], List [nodes .Postprocessing ]],
208- input_tensor_names : Sequence [str ],
209- output_tensor_names : Sequence [str ] = tuple (),
210- ):
211- prep = all (isinstance (ps , nodes .Preprocessing ) for ps in processing_spec )
212- assert prep or all (isinstance (ps , nodes .Postprocessing ) for ps in processing_spec )
213-
214- self .tensor_names = input_tensor_names if prep else output_tensor_names
215- self .tensor_names = input_tensor_names if prep else output_tensor_names
216- self .procs = [KNOWN_PREPROCESSING .get (step .name )(** step .kwargs ) for step in processing_spec ]
@dataclass
class ZeroMeanUnitVariance(Processing):
    """normalize the tensor to zero mean and unit variance

    mode:
        "fixed": use the `mean`/`std` kwargs,
        "per_sample": compute mean/std from the tensor itself in apply(),
        "per_dataset": look up precomputed dataset statistics.
    """

    mode: Literal["fixed", "per_sample", "per_dataset"] = "per_sample"
    mean: Optional[float] = None  # required (only) for mode="fixed"
    std: Optional[float] = None  # required (only) for mode="fixed"
    axes: Optional[Sequence[str]] = None  # axes to reduce over; None reduces all
    eps: float = 1.0e-6  # added to std to avoid division by zero

    def _axes_tuple(self):
        """axes normalized to a hashable tuple (or None) for use in measure keys"""
        return None if self.axes is None else tuple(self.axes)

    def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
        if self.mode == "per_dataset":
            # BUG FIX: request Mean/Std over the same axes that apply() looks up —
            # Mean() and Mean(axes) are different dict keys when axes is not None,
            # so the computed statistics were never found.
            axes = self._axes_tuple()
            return {self.tensor_name: {Mean(axes), Std(axes)}}
        else:
            return {}

    def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
        if self.mode == "per_sample":
            # NOTE(review): apply() currently computes per-sample mean/std inline
            # and never looks these up; keyed by axes for consistency anyway.
            axes = self._axes_tuple()
            return {self.tensor_name: {Mean(axes), Std(axes)}}
        else:
            return {}

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        axes = self._axes_tuple()
        if self.mode == "fixed":
            assert self.mean is not None and self.std is not None
            mean, std = self.mean, self.std
        elif self.mode == "per_sample":
            if axes:
                mean, std = tensor.mean(axes), tensor.std(axes)
            else:
                mean, std = tensor.mean(), tensor.std()
        elif self.mode == "per_dataset":
            mean = self.get_computed_dataset_statistics(self.tensor_name, Mean(axes))
            std = self.get_computed_dataset_statistics(self.tensor_name, Std(axes))
        else:
            raise ValueError(self.mode)

        tensor = (tensor - mean) / (std + self.eps)
        return ensure_dtype(tensor, dtype="float32")
0 commit comments