Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions bioimageio/core/proc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import InitVar, dataclass, field
from typing import (
Collection,
List,
Literal,
Mapping,
Optional,
Expand All @@ -17,6 +18,7 @@
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
Expand Down Expand Up @@ -688,3 +690,95 @@ def get_proc_class(proc_spec: ProcDescr):
return ZeroMeanUnitVariance
else:
assert_never(proc_spec)

def preproc_v4_to_processing(inp: v0_4.InputTensorDescr, proc_spec: v0_4.PreprocessingDescr,) -> Processing:
from bioimageio.spec.model.v0_5 import _convert_proc # pyright: ignore [reportPrivateUsage]
member_id = TensorId(str(inp.name))
if isinstance(proc_spec, v0_4.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ClipDescr):
return Clip.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
if proc_spec.kwargs.mode == "fixed":
axes = inp.axes
v5_proc_spec = _convert_proc(proc_spec, axes)
assert isinstance(v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr) #FIXME
return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id)
else:
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id)
else:
assert_never(proc_spec)

def postproc_v4_to_processing(inp: v0_4.OutputTensorDescr, proc_spec: v0_4.PostprocessingDescr,) -> Processing:
from bioimageio.spec.model.v0_5 import _convert_proc # pyright: ignore [reportPrivateUsage]
member_id = TensorId(str(inp.name))
if isinstance(proc_spec, v0_4.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, member_id)
if isinstance(proc_spec, v0_4.ScaleMeanVarianceDescr):
return ScaleMeanVariance.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ClipDescr):
return Clip.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
if proc_spec.kwargs.mode == "fixed":
axes = inp.axes
v5_proc_spec = _convert_proc(proc_spec, axes)
assert isinstance(v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr) #FIXME
return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id)
else:
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id)
else:
assert_never(proc_spec)

def preproc_v5_to_processing(inp: v0_5.InputTensorDescr, proc_spec: v0_5.PreprocessingDescr,) -> Processing:
if isinstance(proc_spec, v0_5.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ClipDescr):
return Clip.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
return EnsureDtype.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
else:
assert_never(proc_spec)

def postproc_v5_to_processing(inp: v0_5.OutputTensorDescr, proc_spec: v0_5.PostprocessingDescr,) -> Processing:
if isinstance(proc_spec, v0_5.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, inp.id)
if isinstance(proc_spec, v0_5.ScaleMeanVarianceDescr):
return ScaleMeanVariance.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ClipDescr):
return Clip.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
return EnsureDtype.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
else:
assert_never(proc_spec)
112 changes: 56 additions & 56 deletions bioimageio/core/proc_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Optional,
Sequence,
Set,
Tuple,
Union,
)

Expand All @@ -17,9 +18,14 @@
from .digest_spec import get_member_ids
from .proc_ops import (
AddKnownDatasetStats,
EnsureDtype,
Processing,
UpdateStats,
get_proc_class,
postproc_v4_to_processing,
postproc_v5_to_processing,
preproc_v4_to_processing,
preproc_v5_to_processing,
)
from .sample import Sample
from .stat_calculators import StatsCalculator
Expand Down Expand Up @@ -135,66 +141,60 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures
{m for m in s.post_measures if isinstance(m, SampleMeasureBase)},
)

def _prepare_v4_preprocs(tensor_descrs: Sequence[v0_4.InputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
member_id = TensorId(str(t_descr.name))
procs.append(
EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
)
for proc_d in t_descr.preprocessing:
procs.append(preproc_v4_to_processing(t_descr, proc_d))
measures = {m for proc in procs for m in proc.required_measures}
return (procs, measures)

def _prepare_v4_postprocs(tensor_descrs: Sequence[v0_4.OutputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
member_id = TensorId(str(t_descr.name))
procs.append(
EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
)
for proc_d in t_descr.postprocessing:
procs.append(postproc_v4_to_processing(t_descr, proc_d))
measures = {m for proc in procs for m in proc.required_measures}
return (procs, measures)

def _prepare_v5_preprocs(tensor_descrs: Sequence[v0_5.InputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
for proc_d in t_descr.preprocessing:
procs.append(preproc_v5_to_processing(t_descr, proc_d))
measures = {m for proc in procs for m in proc.required_measures}
return (procs, measures)

def _prepare_v5_postprocs(tensor_descrs: Sequence[v0_5.OutputTensorDescr]) -> Tuple[List[Processing], Set[Measure]]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
for proc_d in t_descr.postprocessing:
procs.append(postproc_v5_to_processing(t_descr, proc_d))
measures = {m for proc in procs for m in proc.required_measures}
return (procs, measures)


def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
pre_measures: Set[Measure] = set()
post_measures: Set[Measure] = set()

input_ids = set(get_member_ids(model.inputs))
output_ids = set(get_member_ids(model.outputs))

def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
procs: List[Processing] = []
for t_descr in tensor_descrs:
if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
proc_descrs: List[
Union[
v0_4.PreprocessingDescr,
v0_5.PreprocessingDescr,
v0_4.PostprocessingDescr,
v0_5.PostprocessingDescr,
]
] = list(t_descr.preprocessing)
elif isinstance(
t_descr,
(v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
):
proc_descrs = list(t_descr.postprocessing)
else:
assert_never(t_descr)

if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
ensure_dtype = v0_5.EnsureDtypeDescr(
kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
)
if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
proc_descrs.insert(0, ensure_dtype)

proc_descrs.append(ensure_dtype)

for proc_d in proc_descrs:
proc_class = get_proc_class(proc_d)
member_id = (
TensorId(str(t_descr.name))
if isinstance(t_descr, v0_4.TensorDescrBase)
else t_descr.id
)
req = proc_class.from_proc_descr(
proc_d, member_id # pyright: ignore[reportArgumentType]
)
for m in req.required_measures:
if m.member_id in input_ids:
pre_measures.add(m)
elif m.member_id in output_ids:
post_measures.add(m)
else:
raise ValueError("When to raise ")
procs.append(req)
return procs
if isinstance(model, v0_4.ModelDescr):
pre, pre_measures = _prepare_v4_preprocs(model.inputs)
post, post_measures = _prepare_v4_postprocs(model.outputs)
elif isinstance(model, v0_5.ModelDescr):
pre, pre_measures = _prepare_v5_preprocs(model.inputs)
post, post_measures = _prepare_v5_postprocs(model.outputs)
else:
assert_never(model)

return _SetupProcessing(
pre=prepare_procs(model.inputs),
post=prepare_procs(model.outputs),
pre=pre,
post=post,
pre_measures=pre_measures,
post_measures=post_measures,
)
Loading