Skip to content

Commit e25b27a

Browse files
committed
refactor TensorProcessingInfo
1 parent 5f44ba0 commit e25b27a

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@
1212

1313

1414
@dataclasses.dataclass
15-
class ProcessingInfo:
15+
class ProcessingInfoStep:
1616
name: str
1717
kwargs: Dict[str, Any]
1818

1919

2020
@dataclasses.dataclass
21-
class TensorProcessingInfo:
22-
processing_steps: List[ProcessingInfo]
21+
class ProcessingInfo:
22+
steps: List[ProcessingInfoStep]
2323
assert_dtype_before: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match
2424
ensure_dtype_before: Optional[str] = None # cast data type if needed
2525
assert_dtype_after: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match
2626
ensure_dtype_after: Optional[str] = None # throw AssertionError if data type doesn't match
2727

2828

2929
class CombinedProcessing:
30-
def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
30+
def __init__(self, combine_tensors: Dict[TensorName, ProcessingInfo]):
3131
self._procs = []
3232
known = dict(KNOWN_PROCESSING["pre"])
3333
known.update(KNOWN_PROCESSING["post"])
@@ -41,7 +41,7 @@ def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
4141
self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_before))
4242

4343
for tensor_name, info in combine_tensors.items():
44-
for step in info.processing_steps:
44+
for step in info.steps:
4545
self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs))
4646

4747
if info.assert_dtype_after is not None:
@@ -65,13 +65,13 @@ def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.Out
6565
if isinstance(ts, nodes.InputTensor):
6666
# todo: assert nodes.InputTensor.dtype with assert_dtype_before?
6767
# todo: in the long run we do not want to limit model inputs to float32...
68-
combine_tensors[ts.name] = TensorProcessingInfo(
69-
[ProcessingInfo(p.name, kwargs=p.kwargs) for p in ts.preprocessing or []],
68+
combine_tensors[ts.name] = ProcessingInfo(
69+
[ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.preprocessing or []],
7070
ensure_dtype_after="float32",
7171
)
7272
elif isinstance(ts, nodes.OutputTensor):
73-
combine_tensors[ts.name] = TensorProcessingInfo(
74-
[ProcessingInfo(p.name, kwargs=p.kwargs) for p in ts.postprocessing or []],
73+
combine_tensors[ts.name] = ProcessingInfo(
74+
[ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.postprocessing or []],
7575
ensure_dtype_after=ts.data_type,
7676
)
7777
else:

0 commit comments

Comments
 (0)