Commit f5e049d

pytorchbot and ssjia authored
[ET-VK] AOT logic for quantized conv2d (#14669)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #14648 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/333/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/333/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/332/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/333/orig
Differential Revision: [D83437826](https://our.internmc.facebook.com/intern/diff/D83437826/)

@diff-train-skip-merge

Co-authored-by: ssjia <[email protected]>
1 parent 2b20016 commit f5e049d

File tree

12 files changed: +377 -53 lines

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -117,6 +117,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "replace_qdq",
+    srcs = ["replace_qdq.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -150,6 +163,7 @@ runtime.python_library(
         ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
+        ":replace_qdq",
         ":squeeze_unsqueeze_inputs",
         ":tag_memory_meta_pass",
     ]

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
+from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
 from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
     SqueezeUnsqueezeInputs,
 )
@@ -36,6 +37,7 @@
     "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
+    "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",
     "TagMemoryMetaPass",
 ]
backends/vulkan/_passes/replace_qdq.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.utils as utils
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ReplaceQDQPass(ExportPass):
+    """
+    Replace standard quantize/dequantize ops with custom conv-specific ops when they
+    feed into/from quantized convolution operations. This optimization allows the
+    backend to handle quantization more efficiently for convolution operations.
+    """
+
+    def __init__(self):
+        super(ReplaceQDQPass, self).__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        # Track nodes that need to be replaced
+        nodes_to_replace = []
+
+        for node in graph_module.graph.nodes:
+            # Check if this is the custom quantized conv2d op
+            if node.target in [
+                exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
+                exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
+            ]:
+                # Replace quantize op feeding into conv2d (first argument is the quantized input)
+                quantized_input_node = node.args[0]
+                if isinstance(
+                    quantized_input_node, torch.fx.Node
+                ) and utils.is_quant_node(quantized_input_node):
+                    # Get the arguments from the original quantize node
+                    input_tensor = quantized_input_node.args[0]
+                    scale = quantized_input_node.args[1]
+                    zero_point = quantized_input_node.args[2]
+
+                    nodes_to_replace.append(
+                        {
+                            "old_node": quantized_input_node,
+                            "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+                            "args": (input_tensor, scale, zero_point),
+                            "node_type": "quantize_input",
+                        }
+                    )
+
+                # Find dequantize ops that consume the output of this conv2d
+                for user in node.users:
+                    if utils.is_dequant_node(user):
+                        # Get the arguments from the original dequantize node
+                        scale = user.args[1]
+                        zero_point = user.args[2]
+
+                        nodes_to_replace.append(
+                            {
+                                "old_node": user,
+                                "new_target": exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
+                                "args": (
+                                    node,
+                                    scale,
+                                    zero_point,
+                                ),  # node is the conv2d output
+                                "node_type": "dequantize_output",
+                            }
+                        )
+
+        # Apply the replacements
+        for replacement in nodes_to_replace:
+            old_node = replacement["old_node"]
+            new_target = replacement["new_target"]
+            new_args = replacement["args"]
+
+            with graph_module.graph.inserting_before(old_node):
+                new_node = graph_module.graph.create_node(
+                    "call_function", new_target, args=new_args
+                )
+                new_node.meta = old_node.meta.copy()
+                old_node.replace_all_uses_with(new_node)
+
+        # Clean up the graph
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        # Re-trace to validate everything is ok
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
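
For reference, here is a minimal usage sketch of the new pass (not part of this commit). It assumes a torch.fx.GraphModule that has already been lowered to the edge dialect and contains the et_vk quantized conv ops; the helper name run_replace_qdq is hypothetical.

# Hypothetical driver for ReplaceQDQPass -- illustrative only, not from this PR.
import torch

from executorch.backends.vulkan._passes import ReplaceQDQPass


def run_replace_qdq(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # ExportPass instances are callable; invoking one runs call() and returns
    # a PassResult whose .graph_module holds the rewritten graph.
    result = ReplaceQDQPass()(graph_module)
    return result.graph_module

After the pass runs, quantize/dequantize nodes adjacent to the et_vk conv ops appear as quantize_q8ta_for_conv2d / dequantize_q8to_from_conv2d nodes, and the original q/dq nodes are removed by dead-code elimination.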

backends/vulkan/custom_ops_lib.py

Lines changed: 139 additions & 19 deletions
@@ -354,46 +354,124 @@ def linear_q8ta_q8csw(
 lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
 qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
 
-#######################
-## conv2d_q8ta_q8csw ##
-#######################
+############################
+## conv2d_q8ta_q8csw_q8to ##
+############################
 
 
-def conv2d_q8ta_q8csw(
+def conv2d_q8ta_q8csw_q8to(
     x: torch.Tensor,
     input_scale: float,
     input_zero_point: int,
     weights: torch.Tensor,
     weight_sums: torch.Tensor,
     weight_scales: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
     bias: Optional[torch.Tensor],
     kernel_size: list,
     stride: list,
     padding: list,
     dilation: list,
     groups: int,
 ):
-    IC = x.shape[1]
+    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x, input_scale, input_zero_point, -128, 127, x.dtype
+    )
+
+    # Calculate weight dimensions
+    OC = weights.shape[0]
+    assert OC % groups == 0, "Output channels must be divisible by groups"
+    IC_per_group = int(x.shape[1] / groups)
     K_h, K_w = kernel_size[0], kernel_size[1]
 
-    canonical_weight_K_dim = K_h * K_w * IC
+    orig_weight_K_dim = K_h * K_w * IC_per_group
+    # Remove any padding added to in_features dim to align to a multiple of 4
+    if weights.shape[-1] > orig_weight_K_dim:
+        weights = weights[:, :orig_weight_K_dim]
+
     # Remove any padding added to output channels dim to align to a multiple of 4
-    if weights.shape[-1] != canonical_weight_K_dim:
-        weights = weights[:, :canonical_weight_K_dim]
-        weight_scales = weight_scales[:canonical_weight_K_dim]
+    if weight_scales.shape[0] > OC:
+        weight_scales = weight_scales[:OC]
         if bias is not None:
-            bias = bias[:canonical_weight_K_dim]
+            bias = bias[:OC]
+
+    # Reshape to original 4D format (OC, IC, H, W)
+    weights = weights.view(OC, IC_per_group, K_h, K_w)
 
     weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+    # Dequantize weights
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weights,
+        weight_scales,
+        weight_zeros,
+        0,  # axis=0 for output channel quantization
+        -127,
+        127,
+        torch.int8,
+    )
 
-    # Calculate dimensions
-    OC = weights.shape[0]
-    in_features = weights.shape[1]
-    IC = in_features // (K_h * K_w)
+    # Perform convolution
+    out = torch.nn.functional.conv2d(
+        x, weights, bias, stride, padding, dilation, groups
+    )
 
-    # Reshape to original 4D format (OC, IC, H, W)
-    weights = weights.view(OC, IC, K_h, K_w)
+    out = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
+    return out
 
+
+name = "conv2d_q8ta_q8csw_q8to"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        float output_scale,
+        int output_zero_point,
+        Tensor? bias,
+        SymInt[] kernel_size,
+        SymInt[] stride,
+        SymInt[] padding,
+        SymInt[] dilation,
+        SymInt groups) -> Tensor
+    """
+)
+lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+
+
+def conv2d_q8ta_q8csw_q8to_dw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    bias: Optional[torch.Tensor],
+    kernel_size: list,
+    stride: list,
+    padding: list,
+    dilation: list,
+    groups: int,
+):
+    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x, input_scale, input_zero_point, -128, 127, x.dtype
+    )
+
+    # Restore weight to original data layout
+    K_h, K_w, OC = weights.shape
+    weights = weights.permute(2, 0, 1).reshape(OC, 1, K_h, K_w)
+
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
     # Dequantize weights
     weights = torch.ops.quantized_decomposed.dequantize_per_channel(
         weights,
@@ -410,10 +488,14 @@ def conv2d_q8ta_q8csw(
         x, weights, bias, stride, padding, dilation, groups
     )
 
+    out = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
     return out
 
 
-name = "conv2d_q8ta_q8csw"
+name = "conv2d_q8ta_q8csw_q8to_dw"
 lib.define(
     f"""
     {name}(
@@ -423,6 +505,8 @@ def conv2d_q8ta_q8csw(
         Tensor weights,
         Tensor weight_sums,
         Tensor weight_scales,
+        float output_scale,
+        int output_zero_point,
         Tensor? bias,
        SymInt[] kernel_size,
        SymInt[] stride,
@@ -431,8 +515,8 @@
         SymInt groups) -> Tensor
     """
 )
-lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
-conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)
 
 ######################
 ## apply_rotary_emb ##
@@ -452,3 +536,39 @@ def apply_rotary_emb_impl(
 )
 lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
+
+#############################
+## quantize/dequantize ops ##
+#############################
+
+
+def quantize_q8ta_for_conv2d_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+):
+    return torch.ops.quantized_decomposed.quantize_per_tensor(
+        input, scale, zero_point, -128, 127, torch.int8
+    )
+
+
+name = "quantize_q8ta_for_conv2d"
+lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
+lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
+quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)
+
+
+def dequantize_q8to_from_conv2d_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+):
+    return torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input, scale, zero_point, -128, 127, input.dtype
+    )
+
+
+name = "dequantize_q8to_from_conv2d"
+lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
+lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
+dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
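
The reference implementations above can be exercised eagerly once custom_ops_lib is imported. Below is a small sanity-check sketch, not from this PR: it assumes the op namespace is et_vk (as the exir_ops.edge.et_vk.* references in ReplaceQDQPass suggest), uses groups == 1 with unpadded weights, and passes zeros for weight_sums since the reference impl does not read it.

# Sanity-check sketch (assumptions: namespace is et_vk, groups == 1,
# no alignment padding on weights). Not part of this commit.
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 -- registers the ops
import torch

OC, IC, K = 4, 8, 3
x = torch.randn(1, IC, 16, 16)
in_scale, in_zp = 0.05, 0
out_scale, out_zp = 0.1, 0

# Per-channel int8 weights in the packed 2D layout the op expects:
# (OC, K_h * K_w * IC_per_group).
wq = torch.randint(-127, 128, (OC, IC, K, K), dtype=torch.int8)
w_scales = torch.full((OC,), 0.02)
packed = wq.reshape(OC, K * K * IC)
w_sums = torch.zeros(OC, dtype=torch.int32)  # unused by the reference impl

xq = torch.ops.et_vk.quantize_q8ta_for_conv2d(x, in_scale, in_zp)
outq = torch.ops.et_vk.conv2d_q8ta_q8csw_q8to(
    xq, in_scale, in_zp, packed, w_sums, w_scales,
    out_scale, out_zp, None, [K, K], [1, 1], [1, 1], [1, 1], 1,
)
out = torch.ops.et_vk.dequantize_q8to_from_conv2d(outq, out_scale, out_zp)

# Compare against a plain float convolution with dequantized weights; the
# difference should be on the order of the quantization step sizes.
w_float = wq.float() * w_scales.view(-1, 1, 1, 1)
ref = torch.nn.functional.conv2d(x, w_float, None, 1, 1)
assert outq.dtype == torch.int8
print("max abs error:", (out - ref).abs().max().item())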
