
Commit 97e2299

Rename conv -> conv2d, conv1d_nchw -> conv1d_ncl, conv1d_nhwc -> conv1d_nlc

Differential Revision: D82329465
Pull Request resolved: #14310
Parent: 02b39bf

File tree: 36 files changed (+429, -409 lines)

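For reference, the rename is purely mechanical; a minimal Python sketch of the old-to-new mapping (op names are taken from the diffs below, and resolving them assumes the Cadence op library has been registered — the snippet itself is illustrative, not part of the change):

import torch

# Old name                -> new name (this commit)
# quantized_conv_nchw     -> quantized_conv2d_nchw
# quantized_conv_nhwc     -> quantized_conv2d_nhwc
# quantized_conv1d_nchw_* -> quantized_conv1d_ncl_*
# quantized_conv1d_nhwc_* -> quantized_conv1d_nlc_*

# Callers resolve the renamed ops exactly as before, e.g.:
conv2d_nchw = torch.ops.cadence.quantized_conv2d_nchw.per_tensor
conv1d_ncl = torch.ops.cadence.quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor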

backends/cadence/aot/TARGETS
Lines changed: 2 additions & 2 deletions

@@ -153,8 +153,8 @@ executorch_generated_lib(
         "//executorch/backends/cadence/generic/operators:dequantize_per_tensor",
         "//executorch/backends/cadence/generic/operators:quantize_per_tensor",
         "//executorch/backends/cadence/generic/operators:quantized_add_out",
-        "//executorch/backends/cadence/generic/operators:quantized_conv_nchw_out",
-        "//executorch/backends/cadence/generic/operators:quantized_conv_nhwc_out",
+        "//executorch/backends/cadence/generic/operators:quantized_conv2d_nchw_out",
+        "//executorch/backends/cadence/generic/operators:quantized_conv2d_nhwc_out",
         "//executorch/backends/cadence/generic/operators:quantized_fully_connected_out",
         "//executorch/backends/cadence/generic/operators:quantized_layer_norm",
         "//executorch/backends/cadence/generic/operators:quantized_linear_out",

backends/cadence/aot/functions.yaml

Lines changed: 40 additions & 40 deletions
Large diffs are not rendered by default.

backends/cadence/aot/functions_hifi.yaml

Lines changed: 40 additions & 40 deletions
Large diffs are not rendered by default.

backends/cadence/aot/ops_registrations.py

Lines changed: 88 additions & 80 deletions
Large diffs are not rendered by default.

backends/cadence/aot/quantizer/patterns.py
Lines changed: 3 additions & 3 deletions

@@ -247,7 +247,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.default


 class Conv2dPattern(QuantizationPattern):
@@ -286,7 +286,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.default


 class LayerNormPattern(QuantizationPattern):
@@ -460,7 +460,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.default


 # Conv1d + regular relu op fusion
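On the quantizer side, each pattern selects its fused target op by overriding replacement_op, which after this commit returns the renamed overload. A minimal sketch of that override shape, assuming the QuantizationPattern base class from patterns.py (the class name and the elided methods here are hypothetical, not the repository's actual Conv2dPattern):

import torch
from torch._ops import OpOverload

from executorch.backends.cadence.aot.quantizer.patterns import QuantizationPattern


class MyConvPattern(QuantizationPattern):
    # partition_types / get_anchors elided; see Conv2dPattern in patterns.py.

    def replacement_op(self) -> OpOverload:
        # Post-rename target op for fused quantized 2D convolutions.
        return torch.ops.cadence.quantized_conv2d_nchw.default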

backends/cadence/aot/ref_implementations.py
Lines changed: 46 additions & 38 deletions

@@ -623,8 +623,8 @@ def quantized_conv_per_tensor(
     )


-@impl(m, "quantized_conv_nchw.per_tensor")
-def quantized_conv_nchw_per_tensor(
+@impl(m, "quantized_conv2d_nchw.per_tensor")
+def quantized_conv2d_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -679,8 +679,8 @@ def quantized_conv_nchw_per_tensor(
     )


-@impl(m, "quantized_conv_nhwc.per_tensor")
-def quantized_conv_nhwc_per_tensor(
+@impl(m, "quantized_conv2d_nhwc.per_tensor")
+def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -800,7 +800,7 @@ def variant(
             # Call the appropriate base function
             match layout:
                 case "nchw":
-                    return quantized_conv_nchw_per_tensor(
+                    return quantized_conv2d_nchw_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -817,7 +817,7 @@ def variant(
                         out_shift,
                     )
                 case "nhwc":
-                    return quantized_conv_nhwc_per_tensor(
+                    return quantized_conv2d_nhwc_per_tensor(
                         input_tensor,
                         weight,
                         bias,
@@ -841,84 +841,92 @@ def variant(
     return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
-def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
-def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
-def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
-def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> (
+    torch.Tensor
+): ...


-@impl(m, "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True)
-def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True)
-def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
+@impl(m, "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True)
-def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
+@impl(m, "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True)
-def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


 def quantized_relu_common(
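The stub definitions in this file are filled in by the quantized_conv_variant(layout, input_dtype, weight_dtype, is_1d=...) decorator, whose layout dispatch is visible in the hunks at lines 800-820. A condensed, self-contained sketch of that pattern, assuming simplified signatures (the real base kernels take the full per-tensor argument list, and the real decorator also handles is_1d rank checks):

import torch

# Stand-ins for the base kernels defined earlier in ref_implementations.py
# (their real signatures take the full per-tensor argument list).
def quantized_conv2d_nchw_per_tensor(*args) -> torch.Tensor: ...
def quantized_conv2d_nhwc_per_tensor(*args) -> torch.Tensor: ...


def quantized_conv_variant(layout: str, input_dtype, weight_dtype, is_1d: bool = False):
    """Condensed sketch of the decorator used above, not the full implementation."""

    def decorator(_stub):
        def variant(input_tensor: torch.Tensor, weight: torch.Tensor, *rest) -> torch.Tensor:
            # Each typed variant pins the expected input/weight dtypes.
            assert input_tensor.dtype == input_dtype
            assert weight.dtype == weight_dtype
            # Dispatch on memory layout, mirroring the match statement above.
            match layout:
                case "nchw":
                    return quantized_conv2d_nchw_per_tensor(input_tensor, weight, *rest)
                case "nhwc":
                    return quantized_conv2d_nhwc_per_tensor(input_tensor, weight, *rest)
                case _:
                    raise ValueError(f"Unsupported layout: {layout}")

        return variant

    return decorator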

backends/cadence/aot/replace_ops.py
Lines changed: 16 additions & 16 deletions

@@ -787,8 +787,8 @@ class ReplaceTrivialConvWithLinear(ExportPass):

     trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
         exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
-        exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
-        exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
+        exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
+        exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
     }

     def call_operator(self, op, args, kwargs, meta):
@@ -800,8 +800,8 @@ def call_operator(self, op, args, kwargs, meta):
         # extra args holding at least the zero point and scale of input, weight, bias,
         # and output tensor.
         quantized_op = (
-            op == exir_ops.edge.cadence.quantized_conv_nchw.default
-            or op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+            op == exir_ops.edge.cadence.quantized_conv2d_nchw.default
+            or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
         )
         assert (len(args) == 8 and not quantized_op) or (
             len(args) >= 12 and quantized_op
@@ -979,18 +979,18 @@ def call_operator(
     ) -> ProxyValue:
         if op not in {
             exir_ops.edge.cadence.convolution.default,
-            exir_ops.edge.cadence.quantized_conv_nchw.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.default,
         }:
             return super().call_operator(op, args, kwargs, meta)

-        quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default
+        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.default

         if not quantized_op and len(args) == 8 and args[-1] is True:
             # Already in NHWC layout.
             return super().call_operator(op, args, kwargs, meta)

         new_op = (
-            exir_ops.edge.cadence.quantized_conv_nhwc.default
+            exir_ops.edge.cadence.quantized_conv2d_nhwc.default
             if quantized_op
             else exir_ops.edge.cadence.convolution.default
         )
@@ -1067,8 +1067,8 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass):
     # decompose to.
     conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
         exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
-        exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
-        exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
+        exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
+        exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
     }

     def call_operator(self, op, args, kwargs, meta):
@@ -1077,8 +1077,8 @@ def call_operator(self, op, args, kwargs, meta):

         # Get the relevant args from convolution node.
         quantized_op = (
-            op == exir_ops.edge.cadence.quantized_conv_nchw.default
-            or op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+            op == exir_ops.edge.cadence.quantized_conv2d_nchw.default
+            or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
         )
         assert (len(args) == 8 and not quantized_op) or (
             len(args) >= 12 and quantized_op
@@ -1110,7 +1110,7 @@ def call_operator(self, op, args, kwargs, meta):
         # channel_last layout is specified by the channel_last arg of conv
         # op, which is either the last argument (15th) or implicitely False
         # if the op is quantized, or the last argument if not.
-        channel_last = op == exir_ops.edge.cadence.quantized_conv_nhwc.default
+        channel_last = op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default
         # The weight tensor is [out_channels, in_channels, X] for NCHW layout,
         # and [out_channels, X, in_channels] for NHWC layout. Here, X is the
         # kernel_width for conv1d, and X = kernel_height * kernel_width for
@@ -1622,12 +1622,12 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_add.per_tensor,
             [1, 2, 4, 5],
         ),
-        exir_ops.edge.cadence.quantized_conv_nchw: (
-            exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+        exir_ops.edge.cadence.quantized_conv2d_nchw: (
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
             [8, 9, 12, 13],
         ),
-        exir_ops.edge.cadence.quantized_conv_nhwc: (
-            exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+        exir_ops.edge.cadence.quantized_conv2d_nhwc: (
+            exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
             [8, 9, 12, 13],
         ),
         exir_ops.edge.cadence.quantized_fully_connected: (
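Both replacement passes touched above share one structure: a class-level dict mapping conv op overloads to the linear op they decompose to, consulted inside call_operator. A minimal sketch of that shape under the ExportPass API (the pass name is hypothetical, and the real passes also rewrite the argument list, e.g. via im2row, which is elided here):

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class SketchQuantizedConvToLinearPass(ExportPass):
    # Post-rename: both layout variants map to the same quantized_linear target.
    op_map = {
        exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default,
        exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.op_map:
            return super().call_operator(op, args, kwargs, meta)
        # NOTE: a real pass must also remap args to the linear op's signature.
        return super().call_operator(self.op_map[op], args, kwargs, meta)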

backends/cadence/aot/tests/test_ref_implementations.py
Lines changed: 14 additions & 14 deletions

@@ -906,40 +906,40 @@ def test_quantized_conv_per_tensor(

         convs = [
             (
-                torch.ops.cadence.quantized_conv_nchw.per_tensor
+                torch.ops.cadence.quantized_conv2d_nchw.per_tensor
                 if memory_format == torch.contiguous_format
-                else torch.ops.cadence.quantized_conv_nhwc.per_tensor
+                else torch.ops.cadence.quantized_conv2d_nhwc.per_tensor
             )
         ]

         optimized_convs = []
         if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
             if memory_format == torch.contiguous_format:
                 optimized_convs = [
-                    torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
-                    torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
-                    torch.ops.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
             else:
                 optimized_convs = [
-                    torch.ops.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor,
-                    torch.ops.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
-                    torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
         elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
             if memory_format == torch.contiguous_format:
                 optimized_convs = [
-                    torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
-                    torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
-                    torch.ops.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
                 ]
             else:
                 optimized_convs = [
-                    torch.ops.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor,
-                    torch.ops.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
-                    torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
                 ]

         convs.extend(optimized_convs)
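As a quick post-rename sanity check, one could confirm that the new names resolve under torch.ops.cadence once the op registrations module has been imported; a hypothetical snippet, not part of this commit (it assumes importing ops_registrations is what registers the Cadence ops):

import torch
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401 (registers the ops)

for name in (
    "quantized_conv2d_nchw",
    "quantized_conv2d_nhwc",
    "quantized_conv1d_ncl_asym8sxsym8s_asym8s",
    "quantized_conv1d_nlc_asym8sxsym8s_asym8s",
):
    assert hasattr(torch.ops.cadence, name), f"missing renamed op: {name}"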
