11import dataclasses
2- from typing import Dict , Iterable , List , NamedTuple , Optional , Sequence , TypedDict , Union
2+ from typing import Any , Dict , List , Optional , Sequence , Union
33
44from bioimageio .core .resource_io import nodes
5- from ._processing import EnsureDtype , KNOWN_PROCESSING , Processing , TensorName
5+ from ._processing import AssertDtype , EnsureDtype , KNOWN_PROCESSING , Processing , TensorName
66from ._utils import ComputedMeasures , PER_DATASET , PER_SAMPLE , RequiredMeasures , Sample
77
88try :
1111 from typing_extensions import Literal # type: ignore
1212
1313
14+ @dataclasses .dataclass
15+ class Processing :
16+ name : str
17+ kwargs : Dict [str , Any ]
18+
19+
1420@dataclasses .dataclass
1521class TensorProcessingInfo :
16- processing_steps : Union [List [nodes .Preprocessing ], List [nodes .Postprocessing ]]
17- data_type_before : Optional [str ] = None
18- data_type_after : Optional [str ] = None
22+ processing_steps : List [Processing ]
23+ assert_dtype_before : Optional [Union [str , Sequence [str ]]] = None # throw AssertionError if data type doesn't match
24+ ensure_dtype_before : Optional [str ] = None # cast data type if needed
25+ assert_dtype_after : Optional [Union [str , Sequence [str ]]] = None # throw AssertionError if data type doesn't match
26+ ensure_dtype_after : Optional [str ] = None # throw AssertionError if data type doesn't match
1927
2028
2129class CombinedProcessing :
@@ -26,16 +34,22 @@ def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
2634
2735 # ensure all tensors have correct data type before any processing
2836 for tensor_name , info in combine_tensors .items ():
29- if info .data_type_before is not None :
30- self ._procs .append (EnsureDtype (tensor_name = tensor_name , dtype = info .data_type_before ))
37+ if info .assert_dtype_before is not None :
38+ self ._procs .append (AssertDtype (tensor_name = tensor_name , dtype = info .assert_dtype_before ))
39+
40+ if info .ensure_dtype_before is not None :
41+ self ._procs .append (EnsureDtype (tensor_name = tensor_name , dtype = info .ensure_dtype_before ))
3142
3243 for tensor_name , info in combine_tensors .items ():
3344 for step in info .processing_steps :
3445 self ._procs .append (known [step .name ](tensor_name = tensor_name , ** step .kwargs ))
3546
47+ if info .assert_dtype_after is not None :
48+ self ._procs .append (AssertDtype (tensor_name = tensor_name , dtype = info .assert_dtype_after ))
49+
3650 # ensure tensor has correct data type right after its processing
37- if info .data_type_after is not None :
38- self ._procs .append (EnsureDtype (tensor_name = tensor_name , dtype = info .data_type_after ))
51+ if info .ensure_dtype_after is not None :
52+ self ._procs .append (EnsureDtype (tensor_name = tensor_name , dtype = info .ensure_dtype_after ))
3953
4054 self .required_measures : RequiredMeasures = self ._collect_required_measures (self ._procs )
4155 self .tensor_names = list (combine_tensors )
@@ -50,10 +64,15 @@ def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.Out
5064 # todo: cast dtype for inputs before preprocessing? or check dtype?
5165 assert ts .name not in combine_tensors
5266 if isinstance (ts , nodes .InputTensor ):
53- # todo: move preprocessing ensure_dtype here as data_type_after
54- combine_tensors [ts .name ] = TensorProcessingInfo (ts .preprocessing )
67+ # todo: assert nodes.InputTensor.dtype with assert_dtype_before?
68+ combine_tensors [ts .name ] = TensorProcessingInfo (
69+ [Processing (p .name , kwargs = p .kwargs ) for p in ts .preprocessing ]
70+ )
5571 elif isinstance (ts , nodes .OutputTensor ):
56- combine_tensors [ts .name ] = TensorProcessingInfo (ts .postprocessing , None , ts .data_type )
72+ combine_tensors [ts .name ] = TensorProcessingInfo (
73+ [Processing (p .name , kwargs = p .kwargs ) for p in ts .postprocessing ],
74+ ensure_dtype_after = ts .data_type ,
75+ )
5776 else :
5877 raise NotImplementedError (type (ts ))
5978
0 commit comments