|
16 | 16 | import xarray as xr |
17 | 17 | from typing_extensions import Self, assert_never |
18 | 18 |
|
| 19 | +from bioimageio.core.digest_spec import get_member_id |
19 | 20 | from bioimageio.spec.model import v0_4, v0_5 |
20 | | -from bioimageio.spec.model.v0_5 import TensorId |
21 | 21 | from bioimageio.spec.model.v0_5 import ( |
22 | 22 | _convert_proc, # pyright: ignore [reportPrivateUsage] |
23 | 23 | ) |
@@ -672,144 +672,53 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: |
672 | 672 | ] |
673 | 673 |
|
674 | 674 |
|
675 | | -def get_proc_class(proc_spec: ProcDescr): |
676 | | - if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): |
677 | | - return Binarize |
678 | | - elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): |
679 | | - return Clip |
680 | | - elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): |
681 | | - return EnsureDtype |
682 | | - elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): |
683 | | - return FixedZeroMeanUnitVariance |
684 | | - elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): |
685 | | - return ScaleLinear |
| 675 | +def get_proc( |
| 676 | + proc_descr: ProcDescr, |
| 677 | + tensor_descr: Union[ |
| 678 | + v0_4.InputTensorDescr, |
| 679 | + v0_4.OutputTensorDescr, |
| 680 | + v0_5.InputTensorDescr, |
| 681 | + v0_5.OutputTensorDescr, |
| 682 | + ], |
| 683 | +) -> Processing: |
| 684 | + member_id = get_member_id(tensor_descr) |
| 685 | + |
| 686 | + if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): |
| 687 | + return Binarize.from_proc_descr(proc_descr, member_id) |
| 688 | + elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)): |
| 689 | + return Clip.from_proc_descr(proc_descr, member_id) |
| 690 | + elif isinstance(proc_descr, v0_5.EnsureDtypeDescr): |
| 691 | + return EnsureDtype.from_proc_descr(proc_descr, member_id) |
| 692 | + elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr): |
| 693 | + return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) |
| 694 | + elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): |
| 695 | + return ScaleLinear.from_proc_descr(proc_descr, member_id) |
686 | 696 | elif isinstance( |
687 | | - proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) |
| 697 | + proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) |
688 | 698 | ): |
689 | | - return ScaleMeanVariance |
690 | | - elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): |
691 | | - return ScaleRange |
692 | | - elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): |
693 | | - return Sigmoid |
| 699 | + return ScaleMeanVariance.from_proc_descr(proc_descr, member_id) |
| 700 | + elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): |
| 701 | + return ScaleRange.from_proc_descr(proc_descr, member_id) |
| 702 | + elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): |
| 703 | + return Sigmoid.from_proc_descr(proc_descr, member_id) |
694 | 704 | elif ( |
695 | | - isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) |
696 | | - and proc_spec.kwargs.mode == "fixed" |
| 705 | + isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr) |
| 706 | + and proc_descr.kwargs.mode == "fixed" |
697 | 707 | ): |
698 | | - return FixedZeroMeanUnitVariance |
| 708 | + if not isinstance( |
| 709 | + tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr) |
| 710 | + ): |
| 711 | + raise TypeError( |
| 712 | + "Expected v0_4 tensor description for v0_4 processing description" |
| 713 | + ) |
| 714 | + |
| 715 | + v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes) |
| 716 | + assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr) |
| 717 | + return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id) |
699 | 718 | elif isinstance( |
700 | | - proc_spec, |
| 719 | + proc_descr, |
701 | 720 | (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), |
702 | 721 | ): |
703 | | - return ZeroMeanUnitVariance |
704 | | - else: |
705 | | - assert_never(proc_spec) |
706 | | - |
707 | | - |
708 | | -def preproc_v4_to_processing( |
709 | | - inp: v0_4.InputTensorDescr, |
710 | | - proc_spec: v0_4.PreprocessingDescr, |
711 | | -) -> Processing: |
712 | | - member_id = TensorId(str(inp.name)) |
713 | | - if isinstance(proc_spec, v0_4.BinarizeDescr): |
714 | | - return Binarize.from_proc_descr(proc_spec, member_id) |
715 | | - elif isinstance(proc_spec, v0_4.ClipDescr): |
716 | | - return Clip.from_proc_descr(proc_spec, member_id) |
717 | | - elif isinstance(proc_spec, v0_4.ScaleLinearDescr): |
718 | | - return ScaleLinear.from_proc_descr(proc_spec, member_id) |
719 | | - elif isinstance(proc_spec, v0_4.ScaleRangeDescr): |
720 | | - return ScaleRange.from_proc_descr(proc_spec, member_id) |
721 | | - elif isinstance(proc_spec, v0_4.SigmoidDescr): |
722 | | - return Sigmoid.from_proc_descr(proc_spec, member_id) |
723 | | - elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr): |
724 | | - if proc_spec.kwargs.mode == "fixed": |
725 | | - axes = inp.axes |
726 | | - v5_proc_spec = _convert_proc(proc_spec, axes) |
727 | | - assert isinstance( |
728 | | - v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr |
729 | | - ) # FIXME |
730 | | - return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id) |
731 | | - else: |
732 | | - return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id) |
733 | | - else: |
734 | | - assert_never(proc_spec) |
735 | | - |
736 | | - |
737 | | -def postproc_v4_to_processing( |
738 | | - inp: v0_4.OutputTensorDescr, |
739 | | - proc_spec: v0_4.PostprocessingDescr, |
740 | | -) -> Processing: |
741 | | - member_id = TensorId(str(inp.name)) |
742 | | - if isinstance(proc_spec, v0_4.BinarizeDescr): |
743 | | - return Binarize.from_proc_descr(proc_spec, member_id) |
744 | | - if isinstance(proc_spec, v0_4.ScaleMeanVarianceDescr): |
745 | | - return ScaleMeanVariance.from_proc_descr(proc_spec, member_id) |
746 | | - elif isinstance(proc_spec, v0_4.ClipDescr): |
747 | | - return Clip.from_proc_descr(proc_spec, member_id) |
748 | | - elif isinstance(proc_spec, v0_4.ScaleLinearDescr): |
749 | | - return ScaleLinear.from_proc_descr(proc_spec, member_id) |
750 | | - elif isinstance(proc_spec, v0_4.ScaleRangeDescr): |
751 | | - return ScaleRange.from_proc_descr(proc_spec, member_id) |
752 | | - elif isinstance(proc_spec, v0_4.SigmoidDescr): |
753 | | - return Sigmoid.from_proc_descr(proc_spec, member_id) |
754 | | - elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr): |
755 | | - if proc_spec.kwargs.mode == "fixed": |
756 | | - axes = inp.axes |
757 | | - v5_proc_spec = _convert_proc(proc_spec, axes) |
758 | | - assert isinstance( |
759 | | - v5_proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr |
760 | | - ) # FIXME |
761 | | - return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_spec, member_id) |
762 | | - else: |
763 | | - return ZeroMeanUnitVariance.from_proc_descr(proc_spec, member_id) |
764 | | - else: |
765 | | - assert_never(proc_spec) |
766 | | - |
767 | | - |
768 | | -def preproc_v5_to_processing( |
769 | | - inp: v0_5.InputTensorDescr, |
770 | | - proc_spec: v0_5.PreprocessingDescr, |
771 | | -) -> Processing: |
772 | | - if isinstance(proc_spec, v0_5.BinarizeDescr): |
773 | | - return Binarize.from_proc_descr(proc_spec, inp.id) |
774 | | - elif isinstance(proc_spec, v0_5.ClipDescr): |
775 | | - return Clip.from_proc_descr(proc_spec, inp.id) |
776 | | - elif isinstance(proc_spec, v0_5.ScaleLinearDescr): |
777 | | - return ScaleLinear.from_proc_descr(proc_spec, inp.id) |
778 | | - elif isinstance(proc_spec, v0_5.ScaleRangeDescr): |
779 | | - return ScaleRange.from_proc_descr(proc_spec, inp.id) |
780 | | - elif isinstance(proc_spec, v0_5.SigmoidDescr): |
781 | | - return Sigmoid.from_proc_descr(proc_spec, inp.id) |
782 | | - elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): |
783 | | - return EnsureDtype.from_proc_descr(proc_spec, inp.id) |
784 | | - elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr): |
785 | | - return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) |
786 | | - elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): |
787 | | - return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) |
788 | | - else: |
789 | | - assert_never(proc_spec) |
790 | | - |
791 | | - |
792 | | -def postproc_v5_to_processing( |
793 | | - inp: v0_5.OutputTensorDescr, |
794 | | - proc_spec: v0_5.PostprocessingDescr, |
795 | | -) -> Processing: |
796 | | - if isinstance(proc_spec, v0_5.BinarizeDescr): |
797 | | - return Binarize.from_proc_descr(proc_spec, inp.id) |
798 | | - if isinstance(proc_spec, v0_5.ScaleMeanVarianceDescr): |
799 | | - return ScaleMeanVariance.from_proc_descr(proc_spec, inp.id) |
800 | | - elif isinstance(proc_spec, v0_5.ClipDescr): |
801 | | - return Clip.from_proc_descr(proc_spec, inp.id) |
802 | | - elif isinstance(proc_spec, v0_5.ScaleLinearDescr): |
803 | | - return ScaleLinear.from_proc_descr(proc_spec, inp.id) |
804 | | - elif isinstance(proc_spec, v0_5.ScaleRangeDescr): |
805 | | - return ScaleRange.from_proc_descr(proc_spec, inp.id) |
806 | | - elif isinstance(proc_spec, v0_5.SigmoidDescr): |
807 | | - return Sigmoid.from_proc_descr(proc_spec, inp.id) |
808 | | - elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): |
809 | | - return EnsureDtype.from_proc_descr(proc_spec, inp.id) |
810 | | - elif isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr): |
811 | | - return ZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) |
812 | | - elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): |
813 | | - return FixedZeroMeanUnitVariance.from_proc_descr(proc_spec, inp.id) |
| 722 | + return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) |
814 | 723 | else: |
815 | | - assert_never(proc_spec) |
| 724 | + assert_never(proc_descr) |
0 commit comments