@@ -623,8 +623,8 @@ def quantized_conv_per_tensor(
623
623
)
624
624
625
625
626
- @impl (m , "quantized_conv_nchw .per_tensor" )
627
- def quantized_conv_nchw_per_tensor (
626
+ @impl (m , "quantized_conv2d_nchw .per_tensor" )
627
+ def quantized_conv2d_nchw_per_tensor (
628
628
input_tensor : torch .Tensor ,
629
629
weight : torch .Tensor ,
630
630
bias : torch .Tensor ,
@@ -679,8 +679,8 @@ def quantized_conv_nchw_per_tensor(
679
679
)
680
680
681
681
682
- @impl (m , "quantized_conv_nhwc .per_tensor" )
683
- def quantized_conv_nhwc_per_tensor (
682
+ @impl (m , "quantized_conv2d_nhwc .per_tensor" )
683
+ def quantized_conv2d_nhwc_per_tensor (
684
684
input_tensor : torch .Tensor ,
685
685
weight : torch .Tensor ,
686
686
bias : torch .Tensor ,
@@ -800,7 +800,7 @@ def variant(
800
800
# Call the appropriate base function
801
801
match layout :
802
802
case "nchw" :
803
- return quantized_conv_nchw_per_tensor (
803
+ return quantized_conv2d_nchw_per_tensor (
804
804
input_tensor ,
805
805
weight ,
806
806
bias ,
@@ -817,7 +817,7 @@ def variant(
817
817
out_shift ,
818
818
)
819
819
case "nhwc" :
820
- return quantized_conv_nhwc_per_tensor (
820
+ return quantized_conv2d_nhwc_per_tensor (
821
821
input_tensor ,
822
822
weight ,
823
823
bias ,
@@ -841,84 +841,92 @@ def variant(
841
841
return decorator
842
842
843
843
844
- @impl (m , "quantized_conv_nchw_asym8sxsym8s_asym8s .per_tensor" )
844
+ @impl (m , "quantized_conv2d_nchw_asym8sxsym8s_asym8s .per_tensor" )
845
845
@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
846
- def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
846
+ def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
847
847
848
848
849
- @impl (m , "quantized_conv_nchw_asym8uxsym8u_asym8u .per_tensor" )
849
+ @impl (m , "quantized_conv2d_nchw_asym8uxsym8u_asym8u .per_tensor" )
850
850
@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
851
- def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
851
+ def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
852
852
853
853
854
- @impl (m , "quantized_conv_nhwc_asym8sxsym8s_asym8s .per_tensor" )
854
+ @impl (m , "quantized_conv2d_nhwc_asym8sxsym8s_asym8s .per_tensor" )
855
855
@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
856
- def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
856
+ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
857
857
858
858
859
- @impl (m , "quantized_conv_nhwc_asym8uxsym8u_asym8u .per_tensor" )
859
+ @impl (m , "quantized_conv2d_nhwc_asym8uxsym8u_asym8u .per_tensor" )
860
860
@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
861
- def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
861
+ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
862
862
863
863
864
- @impl (m , "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s .per_tensor" )
864
+ @impl (m , "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s .per_tensor" )
865
865
@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
866
- def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
866
+ def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
867
867
868
868
869
- @impl (m , "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u .per_tensor" )
869
+ @impl (m , "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u .per_tensor" )
870
870
@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
871
- def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
871
+ def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
872
872
873
873
874
- @impl (m , "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s .per_tensor" )
874
+ @impl (m , "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s .per_tensor" )
875
875
@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
876
- def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
876
+ def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
877
877
878
878
879
- @impl (m , "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u .per_tensor" )
879
+ @impl (m , "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u .per_tensor" )
880
880
@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
881
- def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
881
+ def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
882
882
883
883
884
- @impl (m , "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s .per_tensor" )
884
+ @impl (m , "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s .per_tensor" )
885
885
@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
886
- def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
886
+ def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor () -> (
887
+ torch .Tensor
888
+ ): ...
887
889
888
890
889
- @impl (m , "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u .per_tensor" )
891
+ @impl (m , "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u .per_tensor" )
890
892
@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
891
- def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
893
+ def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor () -> (
894
+ torch .Tensor
895
+ ): ...
892
896
893
897
894
- @impl (m , "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor" )
898
+ @impl (m , "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor" )
895
899
@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
896
- def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
900
+ def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor () -> (
901
+ torch .Tensor
902
+ ): ...
897
903
898
904
899
- @impl (m , "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u .per_tensor" )
905
+ @impl (m , "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u .per_tensor" )
900
906
@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
901
- def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
907
+ def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor () -> (
908
+ torch .Tensor
909
+ ): ...
902
910
903
911
904
- @impl (m , "quantized_conv1d_nchw_asym8sxsym8s_asym8s .per_tensor" )
912
+ @impl (m , "quantized_conv1d_ncl_asym8sxsym8s_asym8s .per_tensor" )
905
913
@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 , is_1d = True )
906
- def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
914
+ def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
907
915
908
916
909
- @impl (m , "quantized_conv1d_nchw_asym8uxsym8u_asym8u .per_tensor" )
917
+ @impl (m , "quantized_conv1d_ncl_asym8uxsym8u_asym8u .per_tensor" )
910
918
@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 , is_1d = True )
911
- def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
919
+ def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
912
920
913
921
914
- @impl (m , "quantized_conv1d_nhwc_asym8sxsym8s_asym8s .per_tensor" )
922
+ @impl (m , "quantized_conv1d_nlc_asym8sxsym8s_asym8s .per_tensor" )
915
923
@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 , is_1d = True )
916
- def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
924
+ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
917
925
918
926
919
- @impl (m , "quantized_conv1d_nhwc_asym8uxsym8u_asym8u .per_tensor" )
927
+ @impl (m , "quantized_conv1d_nlc_asym8uxsym8u_asym8u .per_tensor" )
920
928
@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 , is_1d = True )
921
- def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
929
+ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
922
930
923
931
924
932
def quantized_relu_common (
0 commit comments