Skip to content

Commit 78b8321

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] Enable IntxWeightOnlyConfig"
## Motivation Be able to test Vulkan lowering via optimum-executorch. ## Context Very similar to the below PR, Int4 weight only quantization is currently enabled in Vulkan via a custom source transform quantizer that replaces linear layers with a custom linear layer that calls a custom weight only quantized linear op. This diff aims to make it so that no Vulkan specific source transforms need to be applied by adding a fusion pattern for weight only quantized linear. ## Changes * Introduce a fusable graph pattern for weight only quantized linear * Add fusion logic for weight only quantized linear in the fuse patterns pass * Add `4w` qmode to the export llama script Differential Revision: [D80293302](https://our.internmc.facebook.com/intern/diff/D80293302/) [ghstack-poisoned]
1 parent 54722ae commit 78b8321

File tree

1 file changed

+0
-96
lines changed

1 file changed

+0
-96
lines changed

backends/vulkan/_passes/fuse_patterns.py

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -4,109 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import operator
8-
from typing import Callable, List, Optional
9-
107
import executorch.backends.vulkan.patterns as vk_patterns
118

129
import torch
1310

1411
from executorch.exir import ExportedProgram
15-
from executorch.exir.dialects._ops import ops as exir_ops
1612
from executorch.exir.pass_base import ExportPass, PassResult
1713

18-
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
19-
20-
21-
def fuse_pattern(
22-
ep: ExportedProgram,
23-
graph_module: torch.fx.GraphModule,
24-
patterns: List[torch.fx.GraphModule],
25-
create_replacement_func: Callable,
26-
) -> int:
27-
total_replaced = 0
28-
29-
for pattern in patterns:
30-
sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
31-
matches = list(sm.match(graph_module.graph))
32-
33-
for partition_to_replace in matches:
34-
create_replacement_func(ep, graph_module, partition_to_replace)
35-
total_replaced += 1
36-
# Remove dead code so they won't be matched again
37-
graph_module.graph.eliminate_dead_code()
38-
39-
return total_replaced
40-
41-
42-
##
43-
## Rotary Embedding
44-
##
45-
46-
47-
def identify_rotary_emb_io_nodes(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: InternalMatch,
) -> Optional[List[torch.fx.Node]]:
    """Collect the input and output nodes of a rotary embedding match.

    Args:
        ep: Exported program (not used by this function).
        graph_module: Graph module (not used by this function).
        match: Match whose ``placeholder_nodes`` and ``returning_nodes``
            are inspected.

    Returns:
        ``[xq, xk, freqs_cos, freqs_sin, xq_out, xk_out]`` when the match
        has exactly four placeholder nodes and two returning nodes;
        ``None`` otherwise.
    """
    # Expected inputs, in order: xq, xk, freqs_cos, freqs_sin.
    inputs = match.placeholder_nodes
    if len(inputs) != 4:
        return None

    # Expected outputs, in order: xq_out, xk_out.
    outputs = match.returning_nodes
    if len(outputs) != 2:
        return None

    return [*inputs, *outputs]
66-
67-
68-
def create_rotary_emb_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: InternalMatch,
):
    """Replace a matched rotary embedding subgraph with one custom op call.

    Inserts a call to ``et_vk.apply_rotary_emb.default`` before the matched
    ``xq_out`` node, adds two ``operator.getitem`` nodes to unpack its
    tuple result, and redirects all users of the original outputs to them.
    Does nothing if the match lacks the expected four inputs and two outputs.

    Args:
        ep: Exported program, forwarded to ``identify_rotary_emb_io_nodes``.
        graph_module: Graph module whose graph is modified in place.
        match: Pattern match describing the subgraph to replace.
    """
    io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match)
    if io_nodes is None:
        return

    assert len(io_nodes) == 6
    xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes

    graph = graph_module.graph
    original_outputs = (xq_out, xk_out)

    # Insert the fused op just before the first original output.
    with graph.inserting_before(xq_out):
        fused_node = graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.apply_rotary_emb.default,
            args=(xq, xk, freqs_cos, freqs_sin),
        )

    # The custom op returns a tuple (xq_out, xk_out), so add one getitem
    # node per tuple element to expose each output individually.
    with graph.inserting_after(fused_node):
        unpacked_outputs = [
            graph.create_node(
                "call_function",
                operator.getitem,
                args=(fused_node, index),
            )
            for index in range(2)
        ]

    # Copy the "val" metadata from each original output when present.
    for old_node, new_node in zip(original_outputs, unpacked_outputs):
        if hasattr(old_node, "meta") and "val" in old_node.meta:
            new_node.meta["val"] = old_node.meta["val"]

    # Redirect every consumer of the original outputs to the new nodes.
    for old_node, new_node in zip(original_outputs, unpacked_outputs):
        old_node.replace_all_uses_with(new_node)
109-
11014

11115
class FusePatternsPass(ExportPass):
11216
def __init__(self, exported_program: ExportedProgram) -> None:

0 commit comments

Comments
 (0)