Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 54 additions & 27 deletions bioimageio/core/proc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId
from bioimageio.spec.model.v0_5 import (
_convert_proc, # pyright: ignore [reportPrivateUsage]
)

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
Expand Down Expand Up @@ -656,35 +660,58 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
ZeroMeanUnitVariance,
]

def proc_descr_v4_to_op(
inp: Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr],
proc_spec: Union[v0_4.PreprocessingDescr, v0_4.PostprocessingDescr],
) -> Processing:
member_id = TensorId(str(inp.name))
if isinstance(proc_spec, v0_4.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, member_id)
if isinstance(proc_spec, v0_4.ScaleMeanVarianceDescr):
return ScaleMeanVariance.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ClipDescr):
return Clip.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, member_id)
elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
if proc_spec.kwargs.mode == "fixed":
axes = inp.axes
v5_proc_spec = _convert_proc(proc_spec, axes)
assert isinstance(
v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr
) # FIXME
return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id)
else:
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id)
else:
assert_never(proc_spec)


def get_proc_class(proc_spec: ProcDescr):
if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
return Binarize
elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
return Clip
def proc_descr_v5_to_op(
inp: Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr],
proc_spec: Union[v0_5.PreprocessingDescr, v0_5.PostprocessingDescr],
) -> Processing:
if isinstance(proc_spec, v0_5.BinarizeDescr):
return Binarize.from_proc_descr(proc_spec, inp.id)
if isinstance(proc_spec, v0_5.ScaleMeanVarianceDescr):
return ScaleMeanVariance.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ClipDescr):
return Clip.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleLinearDescr):
return ScaleLinear.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ScaleRangeDescr):
return ScaleRange.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.SigmoidDescr):
return Sigmoid.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
return EnsureDtype
return EnsureDtype.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
return FixedZeroMeanUnitVariance
elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
return ScaleLinear
elif isinstance(
proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
):
return ScaleMeanVariance
elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
return ScaleRange
elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
return Sigmoid
elif (
isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
and proc_spec.kwargs.mode == "fixed"
):
return FixedZeroMeanUnitVariance
elif isinstance(
proc_spec,
(v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
):
return ZeroMeanUnitVariance
return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id)
else:
assert_never(proc_spec)
123 changes: 63 additions & 60 deletions bioimageio/core/proc_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId

from .digest_spec import get_member_ids
from .proc_ops import (
AddKnownDatasetStats,
EnsureDtype,
Processing,
UpdateStats,
get_proc_class,
proc_descr_v4_to_op,
proc_descr_v5_to_op,
)
from .sample import Sample
from .stat_calculators import StatsCalculator
Expand Down Expand Up @@ -136,65 +137,67 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures
)


def _prepare_v4_preprocs(
tensor_descrs: Sequence[v0_4.InputTensorDescr],
) -> List[Processing]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
member_id = TensorId(str(t_descr.name))
ensure_dtype = EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
procs.append(ensure_dtype)
for proc_d in t_descr.preprocessing:
procs.append(proc_descr_v4_to_op(t_descr, proc_d))
procs.append(ensure_dtype)
return procs


def _prepare_v4_postprocs(
tensor_descrs: Sequence[v0_4.OutputTensorDescr],
) -> List[Processing]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
member_id = TensorId(str(t_descr.name))
ensure_dtype = EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
procs.append(ensure_dtype)
for proc_d in t_descr.postprocessing:
procs.append(proc_descr_v4_to_op(t_descr, proc_d))
procs.append(ensure_dtype)
return procs


def _prepare_v5_preprocs(
tensor_descrs: Sequence[v0_5.InputTensorDescr],
) -> List[Processing]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
for proc_d in t_descr.preprocessing:
procs.append(proc_descr_v5_to_op(t_descr, proc_d))
return procs


def _prepare_v5_postprocs(
tensor_descrs: Sequence[v0_5.OutputTensorDescr],
) -> List[Processing]:
procs: List[Processing] = []
for t_descr in tensor_descrs:
for proc_d in t_descr.postprocessing:
procs.append(proc_descr_v5_to_op(t_descr, proc_d))
return procs


def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
pre_measures: Set[Measure] = set()
post_measures: Set[Measure] = set()

input_ids = set(get_member_ids(model.inputs))
output_ids = set(get_member_ids(model.outputs))

def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
procs: List[Processing] = []
for t_descr in tensor_descrs:
if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
proc_descrs: List[
Union[
v0_4.PreprocessingDescr,
v0_5.PreprocessingDescr,
v0_4.PostprocessingDescr,
v0_5.PostprocessingDescr,
]
] = list(t_descr.preprocessing)
elif isinstance(
t_descr,
(v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
):
proc_descrs = list(t_descr.postprocessing)
else:
assert_never(t_descr)

if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
ensure_dtype = v0_5.EnsureDtypeDescr(
kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
)
if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
proc_descrs.insert(0, ensure_dtype)

proc_descrs.append(ensure_dtype)

for proc_d in proc_descrs:
proc_class = get_proc_class(proc_d)
member_id = (
TensorId(str(t_descr.name))
if isinstance(t_descr, v0_4.TensorDescrBase)
else t_descr.id
)
req = proc_class.from_proc_descr(
proc_d, member_id # pyright: ignore[reportArgumentType]
)
for m in req.required_measures:
if m.member_id in input_ids:
pre_measures.add(m)
elif m.member_id in output_ids:
post_measures.add(m)
else:
raise ValueError("When to raise ")
procs.append(req)
return procs
if isinstance(model, v0_4.ModelDescr):
pre = _prepare_v4_preprocs(model.inputs)
post = _prepare_v4_postprocs(model.outputs)
elif isinstance(model, v0_5.ModelDescr):
pre = _prepare_v5_preprocs(model.inputs)
post = _prepare_v5_postprocs(model.outputs)
else:
assert_never(model)

return _SetupProcessing(
pre=prepare_procs(model.inputs),
post=prepare_procs(model.outputs),
pre_measures=pre_measures,
post_measures=post_measures,
pre=pre,
post=post,
pre_measures={m for proc in pre for m in proc.required_measures},
post_measures={m for proc in post for m in proc.required_measures},
)
Loading