1- from dataclasses import dataclass , field
2- from typing import Any , Dict , Literal , Optional , Sequence , Set , get_args
1+ from dataclasses import dataclass , field , fields
2+ from typing import Any , Dict , Literal , Optional , Sequence , Set , Union , get_args
33
44import numpy as np
55import xarray as xr
66
77from bioimageio .core .statistical_measures import Mean , Measure , Percentile , Std
88
99
10- def ensure_dtype (tensor : xr .DataArray , * , dtype ) -> xr .DataArray :
11- """
12- Convert array to a given datatype
13- """
14- return tensor .astype (dtype )
15-
16-
1710@dataclass
1811class Processing :
1912 """base class for all Pre- and Postprocessing transformations"""
@@ -37,7 +30,7 @@ def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
3730
3831 def set_computed_dataset_statistics (self , computed : Dict [str , Dict [Measure , Any ]]):
3932 """helper to set computed statistics and check if they match the requirements"""
40- for tensor_name , req_measures in self .get_required_dataset_statistics ():
33+ for tensor_name , req_measures in self .get_required_dataset_statistics (). items () :
4134 comp_measures = computed .get (tensor_name , {})
4235 for req_measure in req_measures :
4336 if req_measure not in comp_measures :
@@ -46,7 +39,7 @@ def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]
4639
4740 def set_computed_sample_statistics (self , computed : Dict [str , Dict [Measure , Any ]]):
4841 """helper to set computed statistics and check if they match the requirements"""
49- for tensor_name , req_measures in self .get_required_sample_statistics ():
42+ for tensor_name , req_measures in self .get_required_sample_statistics (). items () :
5043 comp_measures = computed .get (tensor_name , {})
5144 for req_measure in req_measures :
5245 if req_measure not in comp_measures :
@@ -69,15 +62,35 @@ def get_computed_sample_statistics(self, tensor_name: str, measure: Measure):
6962
7063 return ret
7164
65+ def __call__ (self , tensor : xr .DataArray ) -> xr .DataArray :
66+ return self .apply (tensor )
67+
7268 def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
7369 """apply processing to named tensors"""
7470 raise NotImplementedError
7571
7672 def __post_init__ (self ):
7773 """validate common kwargs by their annotations"""
78- if hasattr (self , "mode" ):
79- if self .mode not in get_args (self .mode ):
80- raise NotImplementedError (f"Unsupported mode { self .mode } for { self .__class__ .__name__ } : { self .mode } " )
74+ self .computed_dataset_statistics = {}
75+ self .computed_sample_statistics = {}
76+
77+ for f in fields (self ):
78+ if f .name == "mode" :
79+ assert hasattr (self , "mode" )
80+ if self .mode not in get_args (f .type ):
81+ raise NotImplementedError (
82+ f"Unsupported mode { self .mode } for { self .__class__ .__name__ } : { self .mode } "
83+ )
84+
85+
86+ #
87+ # helpers
88+ #
89+ def ensure_dtype (tensor : xr .DataArray , * , dtype ) -> xr .DataArray :
90+ """
91+ Convert array to a given datatype
92+ """
93+ return tensor .astype (dtype )
8194
8295
8396#
@@ -102,12 +115,20 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
102115 return ensure_dtype (tensor .clip (min = self .min , max = self .max ), dtype = "float32" )
103116
104117
118+ @dataclass
119+ class EnsureDtype (Processing ):
120+ dtype : str
121+
122+ def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
123+ return ensure_dtype (tensor , dtype = self .dtype )
124+
125+
105126@dataclass
106127class ScaleLinear (Processing ):
107128 """scale the tensor with a fixed multiplicative and additive factor"""
108129
109- gain : float
110- offset : float
130+ gain : Union [ float , Sequence [ float ]]
131+ offset : Union [ float , Sequence [ float ]]
111132 axes : Optional [Sequence [str ]] = None
112133
113134 def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
@@ -121,6 +142,12 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
121142
122143 return ensure_dtype (tensor * gain + offset , dtype = "float32" )
123144
145+ def __post_init__ (self ):
146+ super ().__post_init__ ()
147+ if self .axes is None :
148+ assert isinstance (self .gain , (int , float ))
149+ assert isinstance (self .offset , (int , float ))
150+
124151
125152@dataclass
126153class ScaleMeanVariance (Processing ):
0 commit comments