|
| 1 | +import warnings |
| 2 | +from collections import defaultdict |
| 3 | +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type |
| 4 | + |
| 5 | +import xarray as xr |
| 6 | + |
| 7 | +from bioimageio.core.resource_io import nodes |
| 8 | +from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std |
| 9 | +from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName |
| 10 | +from ._processing import ( |
| 11 | + Binarize, |
| 12 | + Clip, |
| 13 | + EnsureDtype, |
| 14 | + Processing, |
| 15 | + ScaleLinear, |
| 16 | + ScaleMeanVariance, |
| 17 | + ScaleRange, |
| 18 | + Sigmoid, |
| 19 | + ZeroMeanUnitVariance, |
| 20 | +) |
| 21 | + |
| 22 | +try: |
| 23 | + from typing import Literal |
| 24 | +except ImportError: |
| 25 | + from typing_extensions import Literal # type: ignore |
| 26 | + |
# Lookup table: pre-processing step name -> `Processing` subclass that implements the step.
KNOWN_PREPROCESSING: Dict[PreprocessingName, Type[Processing]] = dict(
    [
        ("binarize", Binarize),
        ("clip", Clip),
        ("scale_linear", ScaleLinear),
        ("scale_range", ScaleRange),
        ("sigmoid", Sigmoid),
        ("zero_mean_unit_variance", ZeroMeanUnitVariance),
    ]
)
| 35 | + |
# Lookup table: post-processing step name -> `Processing` subclass that implements the step.
KNOWN_POSTPROCESSING: Dict[PostprocessingName, Type[Processing]] = dict(
    [
        ("binarize", Binarize),
        ("clip", Clip),
        ("scale_linear", ScaleLinear),
        ("scale_mean_variance", ScaleMeanVariance),
        ("scale_range", ScaleRange),
        ("sigmoid", Sigmoid),
        ("zero_mean_unit_variance", ZeroMeanUnitVariance),
    ]
)
| 45 | + |
| 46 | + |
# Names of the two statistics scopes; the constants are defined first and the alias/set derived after.
SAMPLE: Literal["sample"] = "sample"
DATASET: Literal["dataset"] = "dataset"
Scope = Literal["sample", "dataset"]
SCOPES: Set[Scope] = {SAMPLE, DATASET}
| 51 | + |
| 52 | + |
class CombinedProcessing:
    """Hold and apply the pre- and post-processing steps of a model's input and output tensors.

    The steps are instantiated from the `preprocessing`/`postprocessing` entries of the given tensor
    descriptions. Statistics required by the steps are collected per scope ("sample" or "dataset"):
    sample statistics are computed when processing is applied, dataset statistics have to be provided
    through `set_computed_dataset_statistics`.
    """

    def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTensor]):
        """Instantiate all processing steps and collect the statistics they require.

        Args:
            inputs: input tensor descriptions; their `preprocessing` steps are instantiated in order.
            outputs: output tensor descriptions; their `postprocessing` steps are instantiated in order.

        Raises:
            KeyError: if a step name is not in `KNOWN_PREPROCESSING`/`KNOWN_POSTPROCESSING`.
            NotImplementedError: if any post-processing step requires statistics.
            AssertionError: if an input and an output tensor share a name.
        """
        self._prep = [
            KNOWN_PREPROCESSING[step.name](tensor_name=ipt.name, **step.kwargs)
            for ipt in inputs
            for step in ipt.preprocessing or []
        ]
        # Index instead of `.get(...)`: an unknown step name then raises a KeyError naming the step
        # (as for pre-processing) rather than "'NoneType' object is not callable".
        self._post = [
            KNOWN_POSTPROCESSING[step.name](tensor_name=out.name, **step.kwargs)
            for out in outputs
            for step in out.postprocessing or []
        ]

        # There is a difference between pre- and post-processing:
        # Pre-processing always returns float32, because its output is consumed by the model.
        # Post-processing, however, should return the dtype that is specified in the model spec.
        # todo: cast dtype for inputs before preprocessing? or check dtype?
        for out in outputs:
            self._post.append(EnsureDtype(tensor_name=out.name, dtype=out.data_type))

        self._req_input_stats = {s: self._collect_required_stats(self._prep, s) for s in SCOPES}
        self._req_output_stats = {s: self._collect_required_stats(self._post, s) for s in SCOPES}
        if any(self._req_output_stats[s] for s in SCOPES):
            raise NotImplementedError("computing statistics for output tensors not yet implemented")

        # stays None until `set_computed_dataset_statistics` is called
        self._computed_dataset_stats: Optional[Dict[str, Dict[Measure, Any]]] = None

        self.input_tensor_names = [ipt.name for ipt in inputs]
        self.output_tensor_names = [out.name for out in outputs]
        # Statistics of input and output tensors share one dict keyed by tensor name,
        # so a name must not be used for both an input and an output.
        assert not any(name in self.output_tensor_names for name in self.input_tensor_names)

    @property
    def required_input_dataset_statistics(self) -> Dict[str, Set[Measure]]:
        """Dataset-scope measures required by the pre-processing steps, keyed by tensor name."""
        return self._req_input_stats[DATASET]

    @property
    def required_output_dataset_statistics(self) -> Dict[str, Set[Measure]]:
        """Dataset-scope measures required by the post-processing steps, keyed by tensor name."""
        return self._req_output_stats[DATASET]

    @property
    def computed_dataset_statistics(self) -> Optional[Dict[str, Dict[Measure, Any]]]:
        """The statistics passed to `set_computed_dataset_statistics`, or None if it was not called yet."""
        return self._computed_dataset_stats

    def apply_preprocessing(
        self, *input_tensors: xr.DataArray
    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
        """Compute the required sample statistics and apply all pre-processing steps.

        Args:
            input_tensors: one tensor per entry of `input_tensor_names`, in the same order.

        Returns:
            The processed tensors (ordered like `input_tensor_names`) and the computed sample statistics.
        """
        assert len(input_tensors) == len(self.input_tensor_names)
        tensors = dict(zip(self.input_tensor_names, input_tensors))
        sample_stats = self.compute_sample_statistics(tensors, self._req_input_stats[SAMPLE])
        for proc in self._prep:
            proc.set_computed_sample_statistics(sample_stats)
            # steps are applied in order; each one replaces the tensor it operates on
            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])

        return [tensors[tn] for tn in self.input_tensor_names], sample_stats

    def apply_postprocessing(
        self, *output_tensors: xr.DataArray, input_sample_statistics: Dict[str, Dict[Measure, Any]]
    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
        """Compute the required sample statistics and apply all post-processing steps.

        Args:
            output_tensors: one tensor per entry of `output_tensor_names`, in the same order.
            input_sample_statistics: sample statistics of the input tensors (as returned by
                `apply_preprocessing`). The given dict is not modified.

        Returns:
            The processed tensors (ordered like `output_tensor_names`) and the sample statistics of
            input and output tensors combined.
        """
        assert len(output_tensors) == len(self.output_tensor_names)
        tensors = dict(zip(self.output_tensor_names, output_tensors))
        # Merge into a new dict so that the caller's statistics are not mutated.
        sample_stats = dict(input_sample_statistics)
        sample_stats.update(self.compute_sample_statistics(tensors, self._req_output_stats[SAMPLE]))
        for proc in self._post:
            proc.set_computed_sample_statistics(sample_stats)
            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])

        return [tensors[tn] for tn in self.output_tensor_names], sample_stats

    def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
        """Set the externally computed dataset statistics.

        Which statistics are expected is specified by the `required_input_dataset_statistics` and
        `required_output_dataset_statistics` properties.

        Args:
            computed: computed measure values, keyed by tensor name.

        Raises:
            ValueError: if a measure required for an input tensor is missing in `computed`.
        """
        # always expect input tensor statistics
        for tensor_name, req_measures in self.required_input_dataset_statistics.items():
            comp_measures = computed.get(tensor_name, {})
            for req_measure in req_measures:
                if req_measure not in comp_measures:
                    raise ValueError(f"Missing required measure {req_measure} for input tensor {tensor_name}")

        # as output tensor statistics may initially not be available, we only warn about their absence
        for tensor_name, req_measures in self.required_output_dataset_statistics.items():
            comp_measures = computed.get(tensor_name, {})
            for req_measure in req_measures:
                if req_measure not in comp_measures:
                    warnings.warn(f"Missing required measure {req_measure} for output tensor {tensor_name}")

        self._computed_dataset_stats = computed

        # Set dataset statistics for each pre-processing step. Post-processing steps cannot require
        # statistics yet (`__init__` raises NotImplementedError in that case), so they are not updated.
        for proc in self._prep:
            proc.set_computed_dataset_statistics(self.computed_dataset_statistics)

    @classmethod
    def compute_sample_statistics(
        cls, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
    ) -> Dict[str, Dict[Measure, Any]]:
        """Compute the requested measures for each tensor named in `measures`.

        Args:
            tensors: tensors by name; must contain every tensor name used as key in `measures`.
            measures: the measures to compute, keyed by tensor name.

        Returns:
            The computed values per measure, keyed by tensor name.
        """
        return {tname: cls._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}

    @staticmethod
    def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
        """Compute each of `measures` on `tensor`, reducing over the measure's `axes`.

        Raises:
            NotImplementedError: for a measure that is neither `Mean`, `Std` nor `Percentile`.
        """
        ret: Dict[Measure, Any] = {}
        for measure in measures:
            if isinstance(measure, Mean):
                v = tensor.mean(dim=measure.axes)
            elif isinstance(measure, Std):
                v = tensor.std(dim=measure.axes)
            elif isinstance(measure, Percentile):
                # `quantile` takes a fraction in [0, 1], `measure.n` is divided by 100 to obtain it
                v = tensor.quantile(measure.n / 100.0, dim=measure.axes)
            else:
                raise NotImplementedError(measure)

            ret[measure] = v

        return ret

    @staticmethod
    def _collect_required_stats(
        proc: Sequence[Processing], scope: Literal["sample", "dataset"]
    ) -> Dict[str, Set[Measure]]:
        """Merge the statistics required by all steps in `proc` for the given scope.

        Args:
            proc: the processing steps to query.
            scope: "sample" or "dataset".

        Returns:
            The union of the required measures over all steps, keyed by tensor name.

        Raises:
            ValueError: if `scope` is neither "sample" nor "dataset" (only checked when `proc` is not empty).
        """
        stats = defaultdict(set)
        for p in proc:
            if scope == SAMPLE:
                req = p.get_required_sample_statistics()
            elif scope == DATASET:
                req = p.get_required_dataset_statistics()
            else:
                raise ValueError(scope)

            for tn, ms in req.items():
                stats[tn].update(ms)

        # return a plain dict so that later lookups of unknown tensor names do not create entries
        return dict(stats)
0 commit comments