Commit a9df5ea
Author: ssjia (committed)
[ET-VK][AOT] Enable exporting Q8 Quantized Linear + Convolution
As title. Introduce fusion patterns to enable fusing quantized convolution and linear graph patterns into a custom op.

## Changes

Introduce the concept of using custom pattern detection functions to detect graph patterns, rather than relying solely on SubgraphMatcher. The issue with SubgraphMatcher is that a large number of graph patterns may need to be exported to obtain variants for different combinations of decompositions/quantization workflows; a custom detection function improves maintainability.

Implement detection + replacement functions for quantized linear and quantized conv2d.

Differential Revision: [D81323425](https://our.internmc.facebook.com/intern/diff/D81323425/)

[ghstack-poisoned]
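To make the detection-function idea concrete, here is a minimal hypothetical sketch of such a function. Only `PatternMatch` and its `all_nodes` field are confirmed by this diff (see vulkan_partitioner.py below); the function name, constructor, and matching logic are illustrative assumptions, not the actual implementation in backends/vulkan/patterns.

```python
# Hypothetical sketch of a custom pattern detection function. Only
# PatternMatch.all_nodes is confirmed by this diff; everything else here
# (names, constructor, matching logic) is an illustrative assumption.
from typing import List, Optional

import torch


class PatternMatch:
    # Assumed container shape: the partitioner below only reads .all_nodes.
    def __init__(self, all_nodes: List[torch.fx.Node]) -> None:
        self.all_nodes = all_nodes


def detect_quantized_linear(node: torch.fx.Node) -> Optional[PatternMatch]:
    # Walk outward from an anchor op and collect the surrounding dequantize
    # producers, instead of matching against an exported reference graph.
    # One function can therefore cover many decomposition/quantization
    # variants that would otherwise each need their own exported pattern.
    if node.target != torch.ops.aten.linear.default:
        return None
    matched = [node]
    for arg in node.args:
        if isinstance(arg, torch.fx.Node) and "dequantize" in str(arg.target):
            matched.append(arg)
    if len(matched) == 1:
        return None
    return PatternMatch(matched)
```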
1 parent a052433 commit a9df5ea

File tree

18 files changed: +1358 −279 lines

.github/workflows/pull.yml

Lines changed: 4 additions & 0 deletions

@@ -933,6 +933,10 @@ jobs:
         ./cmake-out/backends/vulkan/test/custom_ops/quantized_linear
         ./cmake-out/backends/vulkan/test/custom_ops/quantized_conv2d
+
+        # Run e2e testing for selected operators. More operators will be tested via this
+        # route in the future.
+        python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
 
   nxp-build-test:
     name: nxp-build-test
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -118,6 +118,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fold_qdq",
+    srcs = ["fold_qdq.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -144,6 +157,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":fold_qdq",
        ":fuse_patterns",
        ":fuse_quantized_ops",
        ":insert_prepack_nodes",

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass
 from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.fuse_quantized_ops import (
     FuseQuantizedOpsTransform,
@@ -30,6 +31,7 @@
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FoldQDQPass",
    "FusePatternsPass",
    "FuseQuantizedOpsTransform",
    "insert_prepack_nodes",
backends/vulkan/_passes/fold_qdq.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.utils as utils
+import torch
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class FoldQDQPass(ExportPass):
+    """
+    Erase Q/DQ chains introduced by the PT2E quantization workflow. It is assumed
+    that all valid quant op patterns have already been fused before this pass.
+    """
+
+    def __init__(self, edge_program: torch.export.ExportedProgram):
+        super(FoldQDQPass, self).__init__()
+        self.edge_program = edge_program
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            # Criteria for a foldable Q/DQ pair: the quantize node has only
+            # one user, and that user is a dequantize node.
+            if utils.is_quant_node(node):
+                if len(node.users) > 1:
+                    continue
+
+                dq_node = None
+                for user in node.users:
+                    if utils.is_dequant_node(user):
+                        dq_node = user
+
+                if dq_node is None:
+                    continue
+
+                original_node = node.args[0]
+                assert isinstance(original_node, torch.fx.Node)
+                dq_node.replace_all_uses_with(original_node)
+
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        # Re-trace to validate everything is ok
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
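For context, a minimal usage sketch of the new pass (hedged: the surrounding export/lowering pipeline is assumed and is not part of this diff):

```python
# Minimal usage sketch for FoldQDQPass, assuming an already-lowered edge
# program; the export/lowering setup itself is not shown in this diff.
import torch

from executorch.backends.vulkan._passes import FoldQDQPass


def fold_trailing_qdq(
    edge_program: torch.export.ExportedProgram,
) -> torch.fx.GraphModule:
    # Run after the fusion passes, so any quantize -> dequantize pair that
    # survives is a no-op chain that can be folded back to its input.
    result = FoldQDQPass(edge_program).call(edge_program.graph_module)
    return result.graph_module
```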

backends/vulkan/custom_ops_lib.py

Lines changed: 130 additions & 0 deletions

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns
 import torch.library
 
@@ -321,6 +323,134 @@ def linear_qta8a_qga4w(
 lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
 linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
 
+#######################
+## linear_q8ta_q8csw ##
+#######################
+
+
+def linear_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    qweights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+):
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+    qweights = qweights.transpose(0, 1)
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        qweights,
+        weight_scales,
+        weight_zeros,
+        0,
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform linear operation
+    out = torch.nn.functional.linear(x, weights)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
+name = "linear_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor qweight,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
+qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
+
+#######################
+## conv2d_q8ta_q8csw ##
+#######################
+
+
+def conv2d_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    qweights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    kernel_size: list,
+    stride: list,
+    padding: list,
+    dilation: list,
+    groups: int,
+):
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+
+    # Restore the weight tensor from 2D format (IC * H * W, OC) back to 4D
+    # format (OC, IC, H, W). First transpose to get (OC, IC * H * W).
+    qweights_transposed = qweights.transpose(0, 1)
+
+    # Extract kernel dimensions from the provided kernel_size
+    H, W = kernel_size[0], kernel_size[1]
+
+    # Calculate dimensions
+    OC = qweights_transposed.shape[0]
+    IC_H_W = qweights_transposed.shape[1]
+    IC = IC_H_W // (H * W)
+
+    # Reshape to original 4D format (OC, IC, H, W)
+    qweights_4d = qweights_transposed.view(OC, IC, H, W)
+
+    # Dequantize weights
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        qweights_4d,
+        weight_scales,
+        weight_zeros,
+        0,  # axis=0 for output channel quantization
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform convolution
+    out = torch.nn.functional.conv2d(
+        x, weights, bias, stride, padding, dilation, groups
+    )
+
+    return out
+
+
+name = "conv2d_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor qweight,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? bias,
+        SymInt[] kernel_size,
+        SymInt[] stride,
+        SymInt[] padding,
+        SymInt[] dilation,
+        SymInt groups) -> Tensor
+    """
+)
+lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
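A hedged eager-mode sketch of exercising the two new reference ops. The `et_vk` namespace is inferred from op_registry.py below; tensor shapes follow the reference implementations above (weights stored transposed, per-output-channel scales), and the `weight_sums` construction is an assumption, since the reference impls above never read that argument:

```python
# Sketch only: exercises the reference (CompositeExplicitAutograd) impls of
# the new ops. Assumes executorch.backends.vulkan.custom_ops_lib has been
# imported so the et_vk ops are registered; the weight_sums construction is
# an assumption (the reference impls above do not consume it).
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401
import torch

IC, OC, H, W = 8, 16, 3, 3

# linear_q8ta_q8csw: qweight stored as (IC, OC), scales per output channel
x = torch.randn(1, 4, IC)
qweight = torch.randint(-127, 128, (IC, OC), dtype=torch.int8)
weight_sums = qweight.to(torch.int32).sum(dim=0)  # per output channel
weight_scales = torch.full((OC,), 0.02)
y = torch.ops.et_vk.linear_q8ta_q8csw(
    x, 0.1, 0, qweight, weight_sums, weight_scales, None
)

# conv2d_q8ta_q8csw: qweight stored as (IC * H * W, OC), restored to
# (OC, IC, H, W) inside the reference implementation
xc = torch.randn(1, IC, 32, 32)
qw = torch.randint(-127, 128, (IC * H * W, OC), dtype=torch.int8)
yc = torch.ops.et_vk.conv2d_q8ta_q8csw(
    xc, 0.1, 0, qw, qw.to(torch.int32).sum(dim=0), weight_scales, None,
    [H, W], [1, 1], [1, 1], [1, 1], 1,
)
```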

backends/vulkan/op_registry.py

Lines changed: 39 additions & 0 deletions

@@ -318,6 +318,19 @@ def register_int8_mm_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
+    ]
+)
+def register_qa_qw_linear():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        supports_prepacking=True,
+        supports_resize=False,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.et_vk.linear_weight_int4.default,
@@ -457,6 +470,32 @@ def register_convolution_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
+    ]
+)
+def register_quantized_conv_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CHANNELS_PACKED_TEXTURE,  # input
+            utils.NO_STORAGE,  # input_scale (non tensor)
+            utils.NO_STORAGE,  # input_zero_point (non tensor)
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # weight_sums (prepacked)
+            utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # bias (prepacked)
+            utils.NO_STORAGE,  # kernel_size (non tensor)
+            utils.NO_STORAGE,  # stride (non tensor)
+            utils.NO_STORAGE,  # padding (non tensor)
+            utils.NO_STORAGE,  # dilation (non tensor)
+            utils.NO_STORAGE,  # groups (non tensor)
+        ],
+        supports_resize=False,
+        supports_prepacking=True,
+    )
+
+
 @update_features("llama::sdpa_with_kv_cache")
 def register_sdpa_with_kv_cache_op():
     return OpFeatures(

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 5 additions & 4 deletions

@@ -22,6 +22,8 @@
     vulkan_supported_ops,
 )
 
+from executorch.backends.vulkan.patterns import PatternMatch
+
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
@@ -41,7 +43,6 @@
 
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
-from torch.fx.passes.utils.matcher_utils import InternalMatch
 
 # pyre-ignore
 ops_not_to_decompose = [
@@ -60,7 +61,7 @@ def __init__(
         require_dynamic_shape: bool = False,
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
-        fusable_subgraphs: Optional[List[InternalMatch]] = None,
+        fusable_subgraphs: Optional[List[PatternMatch]] = None,
         nn_module_blocklist: Optional[Set[str]] = None,
         nn_module_allowlist: Optional[Set[str]] = None,
     ) -> None:
@@ -72,13 +73,13 @@ def __init__(
             operator_blocklist if operator_blocklist is not None else set()
         )
         self.operator_allowlist = operator_allowlist
-        self.fusable_subgraphs: List[InternalMatch] = (
+        self.fusable_subgraphs: List[PatternMatch] = (
             fusable_subgraphs if fusable_subgraphs is not None else []
         )
         # Create a set of all nodes that are part of fusable subgraphs for quick lookup
         self.fusable_nodes: Set[torch.fx.Node] = set()
         for match in self.fusable_subgraphs:
-            self.fusable_nodes.update(match.nodes_map.values())
+            self.fusable_nodes.update(match.all_nodes)
 
         self.nn_module_blocklist = nn_module_blocklist
         self.nn_module_allowlist = nn_module_allowlist
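The switch from `InternalMatch.nodes_map` to `PatternMatch.all_nodes` keeps the fusable-node lookup a flat set-membership test. A hypothetical sketch of how that set might be consulted during partitioning (the actual support check is outside this diff):

```python
# Hypothetical sketch: how the precomputed fusable-node set might be used
# when deciding node support; the real check is not shown in this diff.
from typing import Set

import torch


def is_in_fusable_subgraph(
    node: torch.fx.Node, fusable_nodes: Set[torch.fx.Node]
) -> bool:
    # Nodes captured by a matched quantized pattern are claimed for the
    # Vulkan partition as a unit, even if their individual ops would not
    # otherwise be registered as supported.
    return node in fusable_nodes
```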

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ runtime.python_library(
         "pattern_registry.py",
         "rope.py",
         "quantized_linear.py",
+        "quantized_convolution.py",
     ],
     visibility = [
         "//executorch/backends/...",
