1212
1313
1414@dataclasses .dataclass
15- class ProcessingInfo :
15+ class ProcessingInfoStep :
1616 name : str
1717 kwargs : Dict [str , Any ]
1818
1919
2020@dataclasses .dataclass
21- class TensorProcessingInfo :
22- processing_steps : List [ProcessingInfo ]
21+ class ProcessingInfo :
22+ steps : List [ProcessingInfoStep ]
2323 assert_dtype_before : Optional [Union [str , Sequence [str ]]] = None # throw AssertionError if data type doesn't match
2424 ensure_dtype_before : Optional [str ] = None # cast data type if needed
2525 assert_dtype_after : Optional [Union [str , Sequence [str ]]] = None # throw AssertionError if data type doesn't match
2626 ensure_dtype_after : Optional [str ] = None # throw AssertionError if data type doesn't match
2727
2828
2929class CombinedProcessing :
30- def __init__ (self , combine_tensors : Dict [TensorName , TensorProcessingInfo ]):
30+ def __init__ (self , combine_tensors : Dict [TensorName , ProcessingInfo ]):
3131 self ._procs = []
3232 known = dict (KNOWN_PROCESSING ["pre" ])
3333 known .update (KNOWN_PROCESSING ["post" ])
@@ -41,7 +41,7 @@ def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
4141 self ._procs .append (EnsureDtype (tensor_name = tensor_name , dtype = info .ensure_dtype_before ))
4242
4343 for tensor_name , info in combine_tensors .items ():
44- for step in info .processing_steps :
44+ for step in info .steps :
4545 self ._procs .append (known [step .name ](tensor_name = tensor_name , ** step .kwargs ))
4646
4747 if info .assert_dtype_after is not None :
@@ -65,13 +65,13 @@ def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.Out
6565 if isinstance (ts , nodes .InputTensor ):
6666 # todo: assert nodes.InputTensor.dtype with assert_dtype_before?
6767 # todo: in the long run we do not want to limit model inputs to float32...
68- combine_tensors [ts .name ] = TensorProcessingInfo (
69- [ProcessingInfo (p .name , kwargs = p .kwargs ) for p in ts .preprocessing or []],
68+ combine_tensors [ts .name ] = ProcessingInfo (
69+ [ProcessingInfoStep (p .name , kwargs = p .kwargs ) for p in ts .preprocessing or []],
7070 ensure_dtype_after = "float32" ,
7171 )
7272 elif isinstance (ts , nodes .OutputTensor ):
73- combine_tensors [ts .name ] = TensorProcessingInfo (
74- [ProcessingInfo (p .name , kwargs = p .kwargs ) for p in ts .postprocessing or []],
73+ combine_tensors [ts .name ] = ProcessingInfo (
74+ [ProcessingInfoStep (p .name , kwargs = p .kwargs ) for p in ts .postprocessing or []],
7575 ensure_dtype_after = ts .data_type ,
7676 )
7777 else :
0 commit comments