Skip to content

Commit e6bd24b

Browse files
committed
Separates pipeline build logic per version
1 parent 9b08015 commit e6bd24b

File tree

2 files changed

+150
-56
lines changed

2 files changed

+150
-56
lines changed

bioimageio/core/proc_ops.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dataclasses import InitVar, dataclass, field
44
from typing import (
55
Collection,
6+
List,
67
Literal,
78
Mapping,
89
Optional,
@@ -17,6 +18,7 @@
1718
from typing_extensions import Self, assert_never
1819

1920
from bioimageio.spec.model import v0_4, v0_5
21+
from bioimageio.spec.model.v0_5 import TensorId
2022

2123
from ._op_base import BlockedOperator, Operator
2224
from .axis import AxisId, PerAxis
@@ -688,3 +690,95 @@ def get_proc_class(proc_spec: ProcDescr):
688690
return ZeroMeanUnitVariance
689691
else:
690692
assert_never(proc_spec)
693+
694+
def preproc_v4_to_processing(inp: v0_4.InputTensorDescr, proc_spec: v0_4.PreprocessingDescr,) -> Processing:
695+
from bioimageio.spec.model.v0_5 import _convert_proc # pyright: ignore [reportPrivateUsage]
696+
member_id = TensorId(str(inp.name))
697+
if isinstance(proc_spec, v0_4.BinarizeDescr):
698+
return Binarize.from_proc_descr(proc_spec, member_id)
699+
elif isinstance(proc_spec, v0_4.ClipDescr):
700+
return Clip.from_proc_descr(proc_spec, member_id)
701+
elif isinstance(proc_spec, v0_4.ScaleLinearDescr):
702+
return ScaleLinear.from_proc_descr(proc_spec, member_id)
703+
elif isinstance(proc_spec, v0_4.ScaleRangeDescr):
704+
return ScaleRange.from_proc_descr(proc_spec, member_id)
705+
elif isinstance(proc_spec, v0_4.SigmoidDescr):
706+
return Sigmoid.from_proc_descr(proc_spec, member_id)
707+
elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
708+
if proc_spec.kwargs.mode == "fixed":
709+
axes = inp.axes
710+
v5_proc_spec = _convert_proc(proc_spec, axes)
711+
assert isinstance(v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr) #FIXME
712+
return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id)
713+
else:
714+
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id)
715+
else:
716+
assert_never(proc_spec)
717+
718+
def postproc_v4_to_processing(inp: v0_4.OutputTensorDescr, proc_spec: v0_4.PostprocessingDescr,) -> Processing:
719+
from bioimageio.spec.model.v0_5 import _convert_proc # pyright: ignore [reportPrivateUsage]
720+
member_id = TensorId(str(inp.name))
721+
if isinstance(proc_spec, v0_4.BinarizeDescr):
722+
return Binarize.from_proc_descr(proc_spec, member_id)
723+
if isinstance(proc_spec, v0_4.ScaleMeanVarianceDescr):
724+
return ScaleMeanVariance.from_proc_descr(proc_spec, member_id)
725+
elif isinstance(proc_spec, v0_4.ClipDescr):
726+
return Clip.from_proc_descr(proc_spec, member_id)
727+
elif isinstance(proc_spec, v0_4.ScaleLinearDescr):
728+
return ScaleLinear.from_proc_descr(proc_spec, member_id)
729+
elif isinstance(proc_spec, v0_4.ScaleRangeDescr):
730+
return ScaleRange.from_proc_descr(proc_spec, member_id)
731+
elif isinstance(proc_spec, v0_4.SigmoidDescr):
732+
return Sigmoid.from_proc_descr(proc_spec, member_id)
733+
elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
734+
if proc_spec.kwargs.mode == "fixed":
735+
axes = inp.axes
736+
v5_proc_spec = _convert_proc(proc_spec, axes)
737+
assert isinstance(v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr) #FIXME
738+
return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id)
739+
else:
740+
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id)
741+
else:
742+
assert_never(proc_spec)
743+
744+
def preproc_v5_to_processing(inp: v0_5.InputTensorDescr, proc_spec: v0_5.PreprocessingDescr,) -> Processing:
745+
if isinstance(proc_spec, v0_5.BinarizeDescr):
746+
return Binarize.from_proc_descr(proc_spec, inp.id)
747+
elif isinstance(proc_spec, v0_5.ClipDescr):
748+
return Clip.from_proc_descr(proc_spec, inp.id)
749+
elif isinstance(proc_spec, v0_5.ScaleLinearDescr):
750+
return ScaleLinear.from_proc_descr(proc_spec, inp.id)
751+
elif isinstance(proc_spec, v0_5.ScaleRangeDescr):
752+
return ScaleRange.from_proc_descr(proc_spec, inp.id)
753+
elif isinstance(proc_spec, v0_5.SigmoidDescr):
754+
return Sigmoid.from_proc_descr(proc_spec, inp.id)
755+
elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
756+
return EnsureDtype.from_proc_descr(proc_spec, inp.id)
757+
elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
758+
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
759+
elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
760+
return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
761+
else:
762+
assert_never(proc_spec)
763+
764+
def postproc_v5_to_processing(inp: v0_5.OutputTensorDescr, proc_spec: v0_5.PostprocessingDescr,) -> Processing:
765+
if isinstance(proc_spec, v0_5.BinarizeDescr):
766+
return Binarize.from_proc_descr(proc_spec, inp.id)
767+
if isinstance(proc_spec, v0_5.ScaleMeanVarianceDescr):
768+
return ScaleMeanVariance.from_proc_descr(proc_spec, inp.id)
769+
elif isinstance(proc_spec, v0_5.ClipDescr):
770+
return Clip.from_proc_descr(proc_spec, inp.id)
771+
elif isinstance(proc_spec, v0_5.ScaleLinearDescr):
772+
return ScaleLinear.from_proc_descr(proc_spec, inp.id)
773+
elif isinstance(proc_spec, v0_5.ScaleRangeDescr):
774+
return ScaleRange.from_proc_descr(proc_spec, inp.id)
775+
elif isinstance(proc_spec, v0_5.SigmoidDescr):
776+
return Sigmoid.from_proc_descr(proc_spec, inp.id)
777+
elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
778+
return EnsureDtype.from_proc_descr(proc_spec, inp.id)
779+
elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
780+
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
781+
elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
782+
return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
783+
else:
784+
assert_never(proc_spec)

bioimageio/core/proc_setup.py

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Optional,
77
Sequence,
88
Set,
9+
Tuple,
910
Union,
1011
)
1112

@@ -17,9 +18,14 @@
1718
from .digest_spec import get_member_ids
1819
from .proc_ops import (
1920
AddKnownDatasetStats,
21+
EnsureDtype,
2022
Processing,
2123
UpdateStats,
2224
get_proc_class,
25+
postproc_v4_to_processing,
26+
postproc_v5_to_processing,
27+
preproc_v4_to_processing,
28+
preproc_v5_to_processing,
2329
)
2430
from .sample import Sample
2531
from .stat_calculators import StatsCalculator
@@ -135,66 +141,60 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures
135141
{m for m in s.post_measures if isinstance(m, SampleMeasureBase)},
136142
)
137143

144+
def _prepare_v4_preprocs(tensor_descrs: Sequence[v0_4.InputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
145+
procs: List[Processing] = []
146+
for t_descr in tensor_descrs:
147+
member_id = TensorId(str(t_descr.name))
148+
procs.append(
149+
EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
150+
)
151+
for proc_d in t_descr.preprocessing:
152+
procs.append(preproc_v4_to_processing(t_descr, proc_d))
153+
measures = {m for proc in procs for m in proc.required_measures}
154+
return (procs, measures)
155+
156+
def _prepare_v4_postprocs(tensor_descrs: Sequence[v0_4.OutputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
157+
procs: List[Processing] = []
158+
for t_descr in tensor_descrs:
159+
member_id = TensorId(str(t_descr.name))
160+
procs.append(
161+
EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
162+
)
163+
for proc_d in t_descr.postprocessing:
164+
procs.append(postproc_v4_to_processing(t_descr, proc_d))
165+
measures = {m for proc in procs for m in proc.required_measures}
166+
return (procs, measures)
167+
168+
def _prepare_v5_preprocs(tensor_descrs: Sequence[v0_5.InputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
169+
procs: List[Processing] = []
170+
for t_descr in tensor_descrs:
171+
for proc_d in t_descr.preprocessing:
172+
procs.append(preproc_v5_to_processing(t_descr, proc_d))
173+
measures = {m for proc in procs for m in proc.required_measures}
174+
return (procs, measures)
175+
176+
def _prepare_v5_postprocs(tensor_descrs: Sequence[v0_5.OutputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
177+
procs: List[Processing] = []
178+
for t_descr in tensor_descrs:
179+
for proc_d in t_descr.postprocessing:
180+
procs.append(postproc_v5_to_processing(t_descr, proc_d))
181+
measures = {m for proc in procs for m in proc.required_measures}
182+
return (procs, measures)
183+
138184

139185
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
140-
pre_measures: Set[Measure] = set()
141-
post_measures: Set[Measure] = set()
142-
143-
input_ids = set(get_member_ids(model.inputs))
144-
output_ids = set(get_member_ids(model.outputs))
145-
146-
def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
147-
procs: List[Processing] = []
148-
for t_descr in tensor_descrs:
149-
if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
150-
proc_descrs: List[
151-
Union[
152-
v0_4.PreprocessingDescr,
153-
v0_5.PreprocessingDescr,
154-
v0_4.PostprocessingDescr,
155-
v0_5.PostprocessingDescr,
156-
]
157-
] = list(t_descr.preprocessing)
158-
elif isinstance(
159-
t_descr,
160-
(v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
161-
):
162-
proc_descrs = list(t_descr.postprocessing)
163-
else:
164-
assert_never(t_descr)
165-
166-
if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
167-
ensure_dtype = v0_5.EnsureDtypeDescr(
168-
kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
169-
)
170-
if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
171-
proc_descrs.insert(0, ensure_dtype)
172-
173-
proc_descrs.append(ensure_dtype)
174-
175-
for proc_d in proc_descrs:
176-
proc_class = get_proc_class(proc_d)
177-
member_id = (
178-
TensorId(str(t_descr.name))
179-
if isinstance(t_descr, v0_4.TensorDescrBase)
180-
else t_descr.id
181-
)
182-
req = proc_class.from_proc_descr(
183-
proc_d, member_id # pyright: ignore[reportArgumentType]
184-
)
185-
for m in req.required_measures:
186-
if m.member_id in input_ids:
187-
pre_measures.add(m)
188-
elif m.member_id in output_ids:
189-
post_measures.add(m)
190-
else:
191-
raise ValueError("When to raise ")
192-
procs.append(req)
193-
return procs
186+
if isinstance(model, v0_4.ModelDescr):
187+
pre, pre_measures = _prepare_v4_preprocs(model.inputs)
188+
post, post_measures = _prepare_v4_postprocs(model.outputs)
189+
elif isinstance(model, v0_5.ModelDescr):
190+
pre, pre_measures = _prepare_v5_preprocs(model.inputs)
191+
post, post_measures = _prepare_v5_postprocs(model.outputs)
192+
else:
193+
assert_never(model)
194194

195195
return _SetupProcessing(
196-
pre=prepare_procs(model.inputs),
197-
post=prepare_procs(model.outputs),
196+
pre=pre,
197+
post=post,
198198
pre_measures=pre_measures,
199199
post_measures=post_measures,
200200
)

0 commit comments

Comments
 (0)