diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index eecf47b1..49ea2da6 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -17,6 +17,10 @@ from typing_extensions import Self, assert_never from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId +from bioimageio.spec.model.v0_5 import ( + _convert_proc, # pyright: ignore [reportPrivateUsage] +) from ._op_base import BlockedOperator, Operator from .axis import AxisId, PerAxis @@ -656,35 +660,58 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ZeroMeanUnitVariance, ] +def proc_descr_v4_to_op( + inp: Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr], + proc_spec: Union[v0_4.PreprocessingDescr, v0_4.PostprocessingDescr], +) -> Processing: + member_id = TensorId(str(inp.name)) + if isinstance(proc_spec, v0_4.BinarizeDescr): + return Binarize.from_proc_descr(proc_spec, member_id) + if isinstance(proc_spec, v0_4.ScaleMeanVarianceDescr): + return ScaleMeanVariance.from_proc_descr(proc_spec, member_id) + elif isinstance(proc_spec, v0_4.ClipDescr): + return Clip.from_proc_descr(proc_spec, member_id) + elif isinstance(proc_spec, v0_4.ScaleLinearDescr): + return ScaleLinear.from_proc_descr(proc_spec, member_id) + elif isinstance(proc_spec, v0_4.ScaleRangeDescr): + return ScaleRange.from_proc_descr(proc_spec, member_id) + elif isinstance(proc_spec, v0_4.SigmoidDescr): + return Sigmoid.from_proc_descr(proc_spec, member_id) + elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr): + if proc_spec.kwargs.mode == "fixed": + axes = inp.axes + v5_proc_spec = _convert_proc(proc_spec, axes) + assert isinstance( + v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr + ) # FIXME + return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id) + else: + return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id) + else: + assert_never(proc_spec) + -def get_proc_class(proc_spec: ProcDescr): - if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): - return Binarize - elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): - return Clip +def proc_descr_v5_to_op( + inp: Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr], + proc_spec: Union[v0_5.PreprocessingDescr, v0_5.PostprocessingDescr], +) -> Processing: + if isinstance(proc_spec, v0_5.BinarizeDescr): + return Binarize.from_proc_descr(proc_spec, inp.id) + if isinstance(proc_spec, v0_5.ScaleMeanVarianceDescr): + return ScaleMeanVariance.from_proc_descr(proc_spec, inp.id) + elif isinstance(proc_spec, v0_5.ClipDescr): + return Clip.from_proc_descr(proc_spec, inp.id) + elif isinstance(proc_spec, v0_5.ScaleLinearDescr): + return ScaleLinear.from_proc_descr(proc_spec, inp.id) + elif isinstance(proc_spec, v0_5.ScaleRangeDescr): + return ScaleRange.from_proc_descr(proc_spec, inp.id) + elif isinstance(proc_spec, v0_5.SigmoidDescr): + return Sigmoid.from_proc_descr(proc_spec, inp.id) elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): - return EnsureDtype + return EnsureDtype.from_proc_descr(proc_spec, inp.id) + elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr): + return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): - return FixedZeroMeanUnitVariance - elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): - return ScaleLinear - elif isinstance( - proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) - ): - return ScaleMeanVariance - elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): - return ScaleRange - elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): - return Sigmoid - elif ( - isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) - and proc_spec.kwargs.mode == "fixed" - ): - return FixedZeroMeanUnitVariance - elif isinstance( - proc_spec, - (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), - ): - return ZeroMeanUnitVariance + return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) else: assert_never(proc_spec) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index b9afb711..24a70b60 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -14,12 +14,13 @@ from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId -from .digest_spec import get_member_ids from .proc_ops import ( AddKnownDatasetStats, + EnsureDtype, Processing, UpdateStats, - get_proc_class, + proc_descr_v4_to_op, + proc_descr_v5_to_op, ) from .sample import Sample from .stat_calculators import StatsCalculator @@ -136,65 +137,67 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures ) +def _prepare_v4_preprocs( + tensor_descrs: Sequence[v0_4.InputTensorDescr], +) -> List[Processing]: + procs: List[Processing] = [] + for t_descr in tensor_descrs: + member_id = TensorId(str(t_descr.name)) + ensure_dtype = EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type) + procs.append(ensure_dtype) + for proc_d in t_descr.preprocessing: + procs.append(proc_descr_v4_to_op(t_descr, proc_d)) + procs.append(ensure_dtype) + return procs + + +def _prepare_v4_postprocs( + tensor_descrs: Sequence[v0_4.OutputTensorDescr], +) -> List[Processing]: + procs: List[Processing] = [] + for t_descr in tensor_descrs: + member_id = TensorId(str(t_descr.name)) + ensure_dtype = EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type) + procs.append(ensure_dtype) + for proc_d in t_descr.postprocessing: + procs.append(proc_descr_v4_to_op(t_descr, proc_d)) + procs.append(ensure_dtype) + return procs + + +def _prepare_v5_preprocs( + tensor_descrs: Sequence[v0_5.InputTensorDescr], +) -> List[Processing]: + procs: List[Processing] = [] + for t_descr in tensor_descrs: + for proc_d in t_descr.preprocessing: + procs.append(proc_descr_v5_to_op(t_descr, proc_d)) + return procs + + +def _prepare_v5_postprocs( + tensor_descrs: Sequence[v0_5.OutputTensorDescr], +) -> List[Processing]: + procs: List[Processing] = [] + for t_descr in tensor_descrs: + for proc_d in t_descr.postprocessing: + procs.append(proc_descr_v5_to_op(t_descr, proc_d)) + return procs + + def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing: - pre_measures: Set[Measure] = set() - post_measures: Set[Measure] = set() - - input_ids = set(get_member_ids(model.inputs)) - output_ids = set(get_member_ids(model.outputs)) - - def prepare_procs(tensor_descrs: Sequence[TensorDescr]): - procs: List[Processing] = [] - for t_descr in tensor_descrs: - if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)): - proc_descrs: List[ - Union[ - v0_4.PreprocessingDescr, - v0_5.PreprocessingDescr, - v0_4.PostprocessingDescr, - v0_5.PostprocessingDescr, - ] - ] = list(t_descr.preprocessing) - elif isinstance( - t_descr, - (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr), - ): - proc_descrs = list(t_descr.postprocessing) - else: - assert_never(t_descr) - - if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): - ensure_dtype = v0_5.EnsureDtypeDescr( - kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type) - ) - if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs: - proc_descrs.insert(0, ensure_dtype) - - proc_descrs.append(ensure_dtype) - - for proc_d in proc_descrs: - proc_class = get_proc_class(proc_d) - member_id = ( - TensorId(str(t_descr.name)) - if isinstance(t_descr, v0_4.TensorDescrBase) - else t_descr.id - ) - req = proc_class.from_proc_descr( - proc_d, member_id # pyright: ignore[reportArgumentType] - ) - for m in req.required_measures: - if m.member_id in input_ids: - pre_measures.add(m) - elif m.member_id in output_ids: - post_measures.add(m) - else: - raise ValueError("When to raise ") - procs.append(req) - return procs + if isinstance(model, v0_4.ModelDescr): + pre = _prepare_v4_preprocs(model.inputs) + post = _prepare_v4_postprocs(model.outputs) + elif isinstance(model, v0_5.ModelDescr): + pre = _prepare_v5_preprocs(model.inputs) + post = _prepare_v5_postprocs(model.outputs) + else: + assert_never(model) return _SetupProcessing( - pre=prepare_procs(model.inputs), - post=prepare_procs(model.outputs), - pre_measures=pre_measures, - post_measures=post_measures, + pre=pre, + post=post, + pre_measures={m for proc in pre for m in proc.required_measures}, + post_measures={m for proc in post for m in proc.required_measures}, )