|
1 | | -from typing import List, Optional, Sequence, Union |
| 1 | +import dataclasses |
| 2 | +from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, TypedDict, Union |
2 | 3 |
|
3 | 4 | from bioimageio.core.resource_io import nodes |
4 | | -from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing |
| 5 | +from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing, TensorName |
5 | 6 | from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample |
6 | 7 |
|
7 | 8 | try: |
|
10 | 11 | from typing_extensions import Literal # type: ignore |
11 | 12 |
|
12 | 13 |
|
@dataclasses.dataclass
class TensorProcessingInfo:
    """Processing specification for a single tensor, consumed by CombinedProcessing."""

    # pre- or postprocessing steps to apply to this tensor, in order
    processing_steps: Union[List[nodes.Preprocessing], List[nodes.Postprocessing]]
    # if given, cast the tensor to this dtype before any processing step runs
    data_type_before: Optional[str] = None
    # if given, cast the tensor to this dtype right after its processing steps ran
    data_type_after: Optional[str] = None
| 20 | + |
13 | 21 | class CombinedProcessing: |
14 | | - def __init__(self, tensor_specs: Union[List[nodes.InputTensor], List[nodes.OutputTensor]]): |
15 | | - PRE: Literal["pre"] = "pre" |
16 | | - POST: Literal["post"] = "post" |
17 | | - proc_prefix: Optional[Literal["pre", "post"]] = None |
| 22 | + def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]): |
18 | 23 | self._procs = [] |
19 | | - for t in tensor_specs: |
20 | | - if isinstance(t, nodes.InputTensor): |
21 | | - steps = t.preprocessing or [] |
22 | | - if proc_prefix is not None and proc_prefix != PRE: |
23 | | - raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}") |
24 | | - |
25 | | - proc_prefix = PRE |
26 | | - elif isinstance(t, nodes.OutputTensor): |
27 | | - steps = t.postprocessing or [] |
28 | | - if proc_prefix is not None and proc_prefix != POST: |
29 | | - raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}") |
30 | | - |
31 | | - proc_prefix = POST |
32 | | - else: |
33 | | - raise NotImplementedError(t) |
| 24 | + known = dict(KNOWN_PROCESSING["pre"]) |
| 25 | + known.update(KNOWN_PROCESSING["post"]) |
| 26 | + |
| 27 | + # ensure all tensors have correct data type before any processing |
| 28 | + for tensor_name, info in combine_tensors.items(): |
| 29 | + if info.data_type_before is not None: |
| 30 | + self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.data_type_before)) |
34 | 31 |
|
35 | | - for step in steps: |
36 | | - self._procs.append(KNOWN_PROCESSING[proc_prefix][step.name](tensor_name=t.name, **step.kwargs)) |
| 32 | + for tensor_name, info in combine_tensors.items(): |
| 33 | + for step in info.processing_steps: |
| 34 | + self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs)) |
37 | 35 |
|
38 | | - # There is a difference between pre-and postprocessing: |
39 | | - # Preprocessing always returns float32, because its output is consumed by the model. |
40 | | - # Postprocessing, however, should return the dtype that is specified in the model spec. |
41 | | - # todo: cast dtype for inputs before preprocessing? or check dtype? |
42 | | - if proc_prefix == POST: |
43 | | - for t in tensor_specs: |
44 | | - self._procs.append(EnsureDtype(tensor_name=t.name, dtype=t.data_type)) |
| 36 | + # ensure tensor has correct data type right after its processing |
| 37 | + if info.data_type_after is not None: |
| 38 | + self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.data_type_after)) |
45 | 39 |
|
46 | 40 | self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs) |
47 | | - if proc_prefix == POST and self.required_measures[PER_DATASET]: |
48 | | - raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented") |
| 41 | + self.tensor_names = list(combine_tensors) |
| 42 | + |
| 43 | + @classmethod |
| 44 | + def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.OutputTensor]]): |
| 45 | + combine_tensors = {} |
| 46 | + for ts in tensor_specs: |
| 47 | + # There is a difference between pre-and postprocessing: |
| 48 | + # Preprocessing always returns float32, because its output is consumed by the model. |
| 49 | + # Postprocessing, however, should return the dtype that is specified in the model spec. |
| 50 | + # todo: cast dtype for inputs before preprocessing? or check dtype? |
| 51 | + assert ts.name not in combine_tensors |
| 52 | + if isinstance(ts, nodes.InputTensor): |
| 53 | + # todo: move preprocessing ensure_dtype here as data_type_after |
| 54 | + combine_tensors[ts.name] = TensorProcessingInfo(ts.preprocessing) |
| 55 | + elif isinstance(ts, nodes.OutputTensor): |
| 56 | + combine_tensors[ts.name] = TensorProcessingInfo(ts.postprocessing, None, ts.data_type) |
| 57 | + else: |
| 58 | + raise NotImplementedError(type(ts)) |
| 59 | + |
| 60 | + inst = cls(combine_tensors) |
| 61 | + for ts in tensor_specs: |
| 62 | + if isinstance(ts, nodes.OutputTensor) and ts.name in inst.required_measures[PER_DATASET]: |
| 63 | + raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented") |
49 | 64 |
|
50 | | - self.tensor_names = [t.name for t in tensor_specs] |
| 65 | + return inst |
51 | 66 |
|
52 | 67 | def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None: |
53 | 68 | for proc in self._procs: |
|
0 commit comments