@@ -660,72 +660,9 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
660660 ZeroMeanUnitVariance ,
661661]
662662
663-
664- def get_proc_class (proc_spec : ProcDescr ):
665- if isinstance (proc_spec , (v0_4 .BinarizeDescr , v0_5 .BinarizeDescr )):
666- return Binarize
667- elif isinstance (proc_spec , (v0_4 .ClipDescr , v0_5 .ClipDescr )):
668- return Clip
669- elif isinstance (proc_spec , v0_5 .EnsureDtypeDescr ):
670- return EnsureDtype
671- elif isinstance (proc_spec , v0_5 .FixedZeroMeanUnitVarianceDescr ):
672- return FixedZeroMeanUnitVariance
673- elif isinstance (proc_spec , (v0_4 .ScaleLinearDescr , v0_5 .ScaleLinearDescr )):
674- return ScaleLinear
675- elif isinstance (
676- proc_spec , (v0_4 .ScaleMeanVarianceDescr , v0_5 .ScaleMeanVarianceDescr )
677- ):
678- return ScaleMeanVariance
679- elif isinstance (proc_spec , (v0_4 .ScaleRangeDescr , v0_5 .ScaleRangeDescr )):
680- return ScaleRange
681- elif isinstance (proc_spec , (v0_4 .SigmoidDescr , v0_5 .SigmoidDescr )):
682- return Sigmoid
683- elif (
684- isinstance (proc_spec , v0_4 .ZeroMeanUnitVarianceDescr )
685- and proc_spec .kwargs .mode == "fixed"
686- ):
687- return FixedZeroMeanUnitVariance
688- elif isinstance (
689- proc_spec ,
690- (v0_4 .ZeroMeanUnitVarianceDescr , v0_5 .ZeroMeanUnitVarianceDescr ),
691- ):
692- return ZeroMeanUnitVariance
693- else :
694- assert_never (proc_spec )
695-
696-
697- def preproc_v4_to_processing (
698- inp : v0_4 .InputTensorDescr ,
699- proc_spec : v0_4 .PreprocessingDescr ,
700- ) -> Processing :
701- member_id = TensorId (str (inp .name ))
702- if isinstance (proc_spec , v0_4 .BinarizeDescr ):
703- return Binarize .from_proc_descr (proc_spec , member_id )
704- elif isinstance (proc_spec , v0_4 .ClipDescr ):
705- return Clip .from_proc_descr (proc_spec , member_id )
706- elif isinstance (proc_spec , v0_4 .ScaleLinearDescr ):
707- return ScaleLinear .from_proc_descr (proc_spec , member_id )
708- elif isinstance (proc_spec , v0_4 .ScaleRangeDescr ):
709- return ScaleRange .from_proc_descr (proc_spec , member_id )
710- elif isinstance (proc_spec , v0_4 .SigmoidDescr ):
711- return Sigmoid .from_proc_descr (proc_spec , member_id )
712- elif isinstance (proc_spec , v0_4 .ZeroMeanUnitVarianceDescr ):
713- if proc_spec .kwargs .mode == "fixed" :
714- axes = inp .axes
715- v5_proc_spec = _convert_proc (proc_spec , axes )
716- assert isinstance (
717- v5_proc_spec , v0_5 .FixedZeroMeanUnitVarianceDescr
718- ) # FIXME
719- return FixedZeroMeanUnitVariance .from_proc_descr (v5_proc_spec , member_id )
720- else :
721- return ZeroMeanUnitVariance .from_proc_descr (proc_spec , member_id )
722- else :
723- assert_never (proc_spec )
724-
725-
726- def postproc_v4_to_processing (
727- inp : v0_4 .OutputTensorDescr ,
728- proc_spec : v0_4 .PostprocessingDescr ,
663+ def proc_descr_v4_to_op (
664+ inp : Union [v0_4 .InputTensorDescr , v0_4 .OutputTensorDescr ],
665+ proc_spec : Union [v0_4 .PreprocessingDescr , v0_4 .PostprocessingDescr ],
729666) -> Processing :
730667 member_id = TensorId (str (inp .name ))
731668 if isinstance (proc_spec , v0_4 .BinarizeDescr ):
@@ -754,33 +691,9 @@ def postproc_v4_to_processing(
754691 assert_never (proc_spec )
755692
756693
757- def preproc_v5_to_processing (
758- inp : v0_5 .InputTensorDescr ,
759- proc_spec : v0_5 .PreprocessingDescr ,
760- ) -> Processing :
761- if isinstance (proc_spec , v0_5 .BinarizeDescr ):
762- return Binarize .from_proc_descr (proc_spec , inp .id )
763- elif isinstance (proc_spec , v0_5 .ClipDescr ):
764- return Clip .from_proc_descr (proc_spec , inp .id )
765- elif isinstance (proc_spec , v0_5 .ScaleLinearDescr ):
766- return ScaleLinear .from_proc_descr (proc_spec , inp .id )
767- elif isinstance (proc_spec , v0_5 .ScaleRangeDescr ):
768- return ScaleRange .from_proc_descr (proc_spec , inp .id )
769- elif isinstance (proc_spec , v0_5 .SigmoidDescr ):
770- return Sigmoid .from_proc_descr (proc_spec , inp .id )
771- elif isinstance (proc_spec , v0_5 .EnsureDtypeDescr ):
772- return EnsureDtype .from_proc_descr (proc_spec , inp .id )
773- elif isinstance (proc_spec , v0_5 .ZeroMeanUnitVarianceDescr ):
774- return ZeroMeanUnitVariance .from_proc_descr (proc_spec , inp .id )
775- elif isinstance (proc_spec , v0_5 .FixedZeroMeanUnitVarianceDescr ):
776- return FixedZeroMeanUnitVariance .from_proc_descr (proc_spec , inp .id )
777- else :
778- assert_never (proc_spec )
779-
780-
781- def postproc_v5_to_processing (
782- inp : v0_5 .OutputTensorDescr ,
783- proc_spec : v0_5 .PostprocessingDescr ,
694+ def proc_descr_v5_to_op (
695+ inp : Union [v0_5 .InputTensorDescr , v0_5 .OutputTensorDescr ],
696+ proc_spec : Union [v0_5 .PreprocessingDescr , v0_5 .PostprocessingDescr ],
784697) -> Processing :
785698 if isinstance (proc_spec , v0_5 .BinarizeDescr ):
786699 return Binarize .from_proc_descr (proc_spec , inp .id )
0 commit comments