@@ -623,8 +623,8 @@ def quantized_conv_per_tensor(
623623 )
624624
625625
626- @impl (m , "quantized_conv_nchw .per_tensor" )
627- def quantized_conv_nchw_per_tensor (
626+ @impl (m , "quantized_conv2d_nchw .per_tensor" )
627+ def quantized_conv2d_nchw_per_tensor (
628628 input_tensor : torch .Tensor ,
629629 weight : torch .Tensor ,
630630 bias : torch .Tensor ,
@@ -679,8 +679,8 @@ def quantized_conv_nchw_per_tensor(
679679 )
680680
681681
682- @impl (m , "quantized_conv_nhwc .per_tensor" )
683- def quantized_conv_nhwc_per_tensor (
682+ @impl (m , "quantized_conv2d_nhwc .per_tensor" )
683+ def quantized_conv2d_nhwc_per_tensor (
684684 input_tensor : torch .Tensor ,
685685 weight : torch .Tensor ,
686686 bias : torch .Tensor ,
@@ -800,7 +800,7 @@ def variant(
800800 # Call the appropriate base function
801801 match layout :
802802 case "nchw" :
803- return quantized_conv_nchw_per_tensor (
803+ return quantized_conv2d_nchw_per_tensor (
804804 input_tensor ,
805805 weight ,
806806 bias ,
@@ -817,7 +817,7 @@ def variant(
817817 out_shift ,
818818 )
819819 case "nhwc" :
820- return quantized_conv_nhwc_per_tensor (
820+ return quantized_conv2d_nhwc_per_tensor (
821821 input_tensor ,
822822 weight ,
823823 bias ,
@@ -841,84 +841,92 @@ def variant(
841841 return decorator
842842
843843
844- @impl (m , "quantized_conv_nchw_asym8sxsym8s_asym8s .per_tensor" )
844+ @impl (m , "quantized_conv2d_nchw_asym8sxsym8s_asym8s .per_tensor" )
845845@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
846- def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
846+ def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
847847
848848
849- @impl (m , "quantized_conv_nchw_asym8uxsym8u_asym8u .per_tensor" )
849+ @impl (m , "quantized_conv2d_nchw_asym8uxsym8u_asym8u .per_tensor" )
850850@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
851- def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
851+ def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
852852
853853
854- @impl (m , "quantized_conv_nhwc_asym8sxsym8s_asym8s .per_tensor" )
854+ @impl (m , "quantized_conv2d_nhwc_asym8sxsym8s_asym8s .per_tensor" )
855855@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
856- def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
856+ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
857857
858858
859- @impl (m , "quantized_conv_nhwc_asym8uxsym8u_asym8u .per_tensor" )
859+ @impl (m , "quantized_conv2d_nhwc_asym8uxsym8u_asym8u .per_tensor" )
860860@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
861- def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
861+ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
862862
863863
864- @impl (m , "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s .per_tensor" )
864+ @impl (m , "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s .per_tensor" )
865865@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
866- def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
866+ def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
867867
868868
869- @impl (m , "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u .per_tensor" )
869+ @impl (m , "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u .per_tensor" )
870870@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
871- def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
871+ def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
872872
873873
874- @impl (m , "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s .per_tensor" )
874+ @impl (m , "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s .per_tensor" )
875875@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
876- def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
876+ def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
877877
878878
879- @impl (m , "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u .per_tensor" )
879+ @impl (m , "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u .per_tensor" )
880880@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
881- def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
881+ def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
882882
883883
884- @impl (m , "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s .per_tensor" )
884+ @impl (m , "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s .per_tensor" )
885885@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
886- def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
886+ def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor () -> (
887+ torch .Tensor
888+ ): ...
887889
888890
889- @impl (m , "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u .per_tensor" )
891+ @impl (m , "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u .per_tensor" )
890892@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
891- def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
893+ def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor () -> (
894+ torch .Tensor
895+ ): ...
892896
893897
894- @impl (m , "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor" )
898+ @impl (m , "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor" )
895899@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
896- def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
900+ def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor () -> (
901+ torch .Tensor
902+ ): ...
897903
898904
899- @impl (m , "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u .per_tensor" )
905+ @impl (m , "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u .per_tensor" )
900906@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
901- def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
907+ def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor () -> (
908+ torch .Tensor
909+ ): ...
902910
903911
904- @impl (m , "quantized_conv1d_nchw_asym8sxsym8s_asym8s .per_tensor" )
912+ @impl (m , "quantized_conv1d_ncl_asym8sxsym8s_asym8s .per_tensor" )
905913@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 , is_1d = True )
906- def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
914+ def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
907915
908916
909- @impl (m , "quantized_conv1d_nchw_asym8uxsym8u_asym8u .per_tensor" )
917+ @impl (m , "quantized_conv1d_ncl_asym8uxsym8u_asym8u .per_tensor" )
910918@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 , is_1d = True )
911- def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
919+ def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
912920
913921
914- @impl (m , "quantized_conv1d_nhwc_asym8sxsym8s_asym8s .per_tensor" )
922+ @impl (m , "quantized_conv1d_nlc_asym8sxsym8s_asym8s .per_tensor" )
915923@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 , is_1d = True )
916- def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
924+ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
917925
918926
919- @impl (m , "quantized_conv1d_nhwc_asym8uxsym8u_asym8u .per_tensor" )
927+ @impl (m , "quantized_conv1d_nlc_asym8uxsym8u_asym8u .per_tensor" )
920928@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 , is_1d = True )
921- def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
929+ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
922930
923931
924932def quantized_relu_common (
0 commit comments