Commit d25f6c4

[ET-VK][AOT] Enable exporting Q8 Quantized Linear + Convolution (#14043)
As title. Introduce fusion patterns to enable fusing quantized convolution and linear graph patterns into a custom op.

## Changes

Introduce the concept of using custom pattern detection functions to detect graph patterns, rather than relying solely on SubgraphMatcher. The issue with SubgraphMatcher is that a large number of graph patterns may need to be exported to obtain variants for different combinations of decompositions/quantization workflows. Having a custom detection function improves maintainability.

Implement detection + replacement functions for quantized linear and quantized conv2d.

Differential Revision: [D81323425](https://our.internmc.facebook.com/intern/diff/D81323425/)
1 parent 2c93fd2 commit d25f6c4

18 files changed: +1556 −282 lines
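
The central change, using plain-Python detection functions instead of exported SubgraphMatcher patterns, can be sketched roughly as follows. This is a simplified illustration: PatternMatch and its all_nodes field appear in this diff, while the registry helpers below are hypothetical stand-ins for the real code in backends/vulkan/patterns.

# Simplified sketch of the detection-function approach described above.
# PatternMatch/all_nodes are taken from this diff; register_detector and
# find_fusable_subgraphs are hypothetical illustrations.
from typing import Callable, List, Optional

import torch


class PatternMatch:
    # Records every node belonging to one matched quantized-op subgraph.
    def __init__(self, all_nodes: List[torch.fx.Node]) -> None:
        self.all_nodes = all_nodes


DetectorFn = Callable[[torch.fx.Node], Optional[PatternMatch]]
_detectors: List[DetectorFn] = []


def register_detector(fn: DetectorFn) -> DetectorFn:
    # Unlike SubgraphMatcher, a detector is ordinary Python, so one function
    # can tolerate many decomposition/quantization variants via control flow
    # instead of requiring one exported graph pattern per variant.
    _detectors.append(fn)
    return fn


def find_fusable_subgraphs(gm: torch.fx.GraphModule) -> List[PatternMatch]:
    matches: List[PatternMatch] = []
    for node in gm.graph.nodes:
        for detect in _detectors:
            match = detect(node)
            if match is not None:
                matches.append(match)
    return matches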

.github/workflows/pull.yml

Lines changed: 4 additions & 0 deletions

@@ -934,6 +934,10 @@ jobs:
         ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
         ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
 
+        # Run e2e testing for selected operators. More operators will be tested via this
+        # route in the future.
+        python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
+
   nxp-build-test:
     name: nxp-build-test
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -118,6 +118,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fold_qdq",
+    srcs = ["fold_qdq.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -144,6 +157,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":fold_qdq",
         ":fuse_patterns",
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass
 from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.fuse_quantized_ops import (
     FuseQuantizedOpsTransform,
@@ -30,6 +31,7 @@
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FoldQDQPass",
     "FusePatternsPass",
     "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
backends/vulkan/_passes/fold_qdq.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.utils as utils
+import torch
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class FoldQDQPass(ExportPass):
+    """
+    Erase Q/DQ chains introduced by the PT2E quantization workflow. It is assumed
+    that all valid quant op patterns have already been fused before this pass.
+    """
+
+    def __init__(self, edge_program: torch.export.ExportedProgram):
+        super(FoldQDQPass, self).__init__()
+        self.edge_program = edge_program
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if utils.is_quant_node(node):
+                original_node = node.args[0]
+                assert isinstance(original_node, torch.fx.Node)
+                # For each direct user that is a dequant node, connect the original
+                # node to the users of the dequant node.
+                for user in node.users:
+                    if utils.is_dequant_node(user):
+                        dq_node = user
+                        dq_node.replace_all_uses_with(original_node)
+
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        # Re-trace to validate everything is ok
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
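
A pass like this would typically run on the program's graph module after the fusion passes. Below is a minimal usage sketch, assuming the standard exir pass-calling convention; the pipeline code is illustrative, not part of this commit.

# Minimal usage sketch; assumes quantized-op fusion has already run.
import torch

from executorch.backends.vulkan._passes import FoldQDQPass


def fold_qdq(edge_program: torch.export.ExportedProgram) -> torch.fx.GraphModule:
    gm = edge_program.graph_module
    # ExportPass instances are callable; __call__ dispatches to call() above.
    result = FoldQDQPass(edge_program)(gm)
    return result.graph_module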

backends/vulkan/custom_ops_lib.py

Lines changed: 131 additions & 0 deletions

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns
 import torch.library
 
@@ -321,6 +323,135 @@ def linear_qta8a_qga4w(
 lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
 linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
 
+#################
+## qaqw_linear ##
+#################
+
+
+def linear_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+):
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weights,
+        weight_scales,
+        weight_zeros,
+        0,
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform linear operation
+    out = torch.nn.functional.linear(x, weights)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
+name = "linear_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
+qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
+
+#######################
+## conv2d_q8ta_q8csw ##
+#######################
+
+
+def conv2d_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    kernel_size: list,
+    stride: list,
+    padding: list,
+    dilation: list,
+    groups: int,
+):
+    IC = x.shape[1]
+    K_h, K_w = kernel_size[0], kernel_size[1]
+
+    canonical_weight_K_dim = K_h * K_w * IC
+    # Remove any padding added to output channels dim to align to a multiple of 4
+    if weights.shape[-1] != canonical_weight_K_dim:
+        weights = weights[:, :canonical_weight_K_dim]
+        weight_scales = weight_scales[:canonical_weight_K_dim]
+        if bias is not None:
+            bias = bias[:canonical_weight_K_dim]
+
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+
+    # Calculate dimensions
+    OC = weights.shape[0]
+    in_features = weights.shape[1]
+    IC = in_features // (K_h * K_w)
+
+    # Reshape to original 4D format (OC, IC, H, W)
+    weights = weights.view(OC, IC, K_h, K_w)
+
+    # Dequantize weights
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weights,
+        weight_scales,
+        weight_zeros,
+        0,  # axis=0 for output channel quantization
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform convolution
+    out = torch.nn.functional.conv2d(
+        x, weights, bias, stride, padding, dilation, groups
+    )
+
+    return out
+
+
+name = "conv2d_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? bias,
+        SymInt[] kernel_size,
+        SymInt[] stride,
+        SymInt[] padding,
+        SymInt[] dilation,
+        SymInt groups) -> Tensor
+    """
+)
+lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
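
Because the reference implementations are registered under CompositeExplicitAutograd, they can be exercised eagerly. Here is a quick sanity-check sketch; shapes and values are illustrative, and the et_vk namespace is taken from the op_registry entries below.

# Illustrative eager invocation of the linear_q8ta_q8csw reference impl.
import torch

import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 (registers the ops)

x = torch.randn(2, 8)
weights = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
weight_scales = torch.rand(4) * 0.1
# Per-output-channel sums of the quantized weights; presumably consumed by
# the GPU kernel, and unused by this reference implementation.
weight_sums = weights.to(torch.int32).sum(dim=1)

out = torch.ops.et_vk.linear_q8ta_q8csw(
    x, 0.05, 0, weights, weight_sums, weight_scales
)
print(out.shape)  # torch.Size([2, 4])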

backends/vulkan/op_registry.py

Lines changed: 40 additions & 0 deletions

@@ -318,6 +318,19 @@ def register_int8_mm_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
+    ]
+)
+def register_qa_qw_linear():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        supports_prepacking=True,
+        supports_resize=False,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.et_vk.linear_weight_int4.default,
@@ -457,6 +470,33 @@ def register_convolution_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
+    ]
+)
+def register_quantized_conv_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CHANNELS_PACKED_TEXTURE,  # input
+            utils.NO_STORAGE,  # input_scale (non tensor)
+            utils.NO_STORAGE,  # input_zero_point (non tensor)
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # weight_sums (prepacked)
+            utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # bias (prepacked)
+            utils.NO_STORAGE,  # kernel_size (non tensor)
+            utils.NO_STORAGE,  # stride (non tensor)
+            utils.NO_STORAGE,  # padding (non tensor)
+            utils.NO_STORAGE,  # dilation (non tensor)
+            utils.NO_STORAGE,  # groups (non tensor)
+            utils.NO_STORAGE,  # original OC count (non tensor)
+        ],
+        supports_resize=False,
+        supports_prepacking=True,
+    )
+
+
 @update_features("llama::sdpa_with_kv_cache")
 def register_sdpa_with_kv_cache_op():
     return OpFeatures(

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 5 additions & 4 deletions

@@ -22,6 +22,8 @@
     vulkan_supported_ops,
 )
 
+from executorch.backends.vulkan.patterns import PatternMatch
+
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
@@ -41,7 +43,6 @@
 
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
-from torch.fx.passes.utils.matcher_utils import InternalMatch
 
 # pyre-ignore
 ops_not_to_decompose = [
@@ -60,7 +61,7 @@ def __init__(
         require_dynamic_shape: bool = False,
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
-        fusable_subgraphs: Optional[List[InternalMatch]] = None,
+        fusable_subgraphs: Optional[List[PatternMatch]] = None,
         nn_module_blocklist: Optional[Set[str]] = None,
         nn_module_allowlist: Optional[Set[str]] = None,
     ) -> None:
@@ -72,13 +73,13 @@ def __init__(
             operator_blocklist if operator_blocklist is not None else set()
         )
         self.operator_allowlist = operator_allowlist
-        self.fusable_subgraphs: List[InternalMatch] = (
+        self.fusable_subgraphs: List[PatternMatch] = (
             fusable_subgraphs if fusable_subgraphs is not None else []
         )
         # Create a set of all nodes that are part of fusable subgraphs for quick lookup
         self.fusable_nodes: Set[torch.fx.Node] = set()
         for match in self.fusable_subgraphs:
-            self.fusable_nodes.update(match.nodes_map.values())
+            self.fusable_nodes.update(match.all_nodes)
 
         self.nn_module_blocklist = nn_module_blocklist
         self.nn_module_allowlist = nn_module_allowlist
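
The quantized_linear.py and quantized_convolution.py pattern files themselves are not shown in this excerpt. For a sense of what a detection function might look like, here is a hypothetical sketch of a quantized-linear detector, reusing the PatternMatch sketch from above; the real logic lives in backends/vulkan/patterns/quantized_linear.py and is certainly more thorough.

# Hypothetical detector sketch; not the actual implementation.
from typing import Optional

import executorch.backends.vulkan.utils as utils
import torch


def detect_q8_linear(node: torch.fx.Node) -> Optional[PatternMatch]:
    # Anchor detection on the linear op itself.
    if node.op != "call_function" or "linear" not in str(node.target):
        return None
    x, weight = node.args[0], node.args[1]
    if not (isinstance(x, torch.fx.Node) and isinstance(weight, torch.fx.Node)):
        return None
    # Expect a dequantized activation and a per-channel dequantized weight,
    # i.e. the Q/DQ representation produced by the PT2E workflow.
    if not (utils.is_dequant_node(x) and utils.is_dequant_node(weight)):
        return None
    return PatternMatch([x, weight, node])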

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ runtime.python_library(
         "pattern_registry.py",
         "rope.py",
         "quantized_linear.py",
+        "quantized_convolution.py",
     ],
     visibility = [
         "//executorch/backends/...",
