|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import operator |
8 | | -from typing import Callable, List, Optional |
9 | | - |
10 | 7 | import executorch.backends.vulkan.patterns as vk_patterns |
11 | 8 |
|
12 | 9 | import torch |
13 | 10 |
|
14 | 11 | from executorch.exir import ExportedProgram |
15 | | -from executorch.exir.dialects._ops import ops as exir_ops |
16 | 12 | from executorch.exir.pass_base import ExportPass, PassResult |
17 | 13 |
|
18 | | -from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher |
19 | | - |
20 | | - |
21 | | -def fuse_pattern( |
22 | | - ep: ExportedProgram, |
23 | | - graph_module: torch.fx.GraphModule, |
24 | | - patterns: List[torch.fx.GraphModule], |
25 | | - create_replacement_func: Callable, |
26 | | -) -> int: |
27 | | - total_replaced = 0 |
28 | | - |
29 | | - for pattern in patterns: |
30 | | - sm = SubgraphMatcher(pattern.graph, ignore_literals=True) |
31 | | - matches = list(sm.match(graph_module.graph)) |
32 | | - |
33 | | - for partition_to_replace in matches: |
34 | | - create_replacement_func(ep, graph_module, partition_to_replace) |
35 | | - total_replaced += 1 |
36 | | - # Remove dead code so they won't be matched again |
37 | | - graph_module.graph.eliminate_dead_code() |
38 | | - |
39 | | - return total_replaced |
40 | | - |
41 | | - |
42 | | -## |
43 | | -## Rotary Embedding |
44 | | -## |
45 | | - |
46 | | - |
47 | | -def identify_rotary_emb_io_nodes( |
48 | | - ep: ExportedProgram, |
49 | | - graph_module: torch.fx.GraphModule, |
50 | | - match: InternalMatch, |
51 | | -) -> Optional[List[torch.fx.Node]]: |
52 | | - # Get the input placeholders (xq, xk, freqs_cos, freqs_sin) |
53 | | - placeholder_nodes = match.placeholder_nodes |
54 | | - if len(placeholder_nodes) != 4: |
55 | | - return None |
56 | | - |
57 | | - xq, xk, freqs_cos, freqs_sin = placeholder_nodes |
58 | | - |
59 | | - output_nodes = match.returning_nodes |
60 | | - if len(output_nodes) != 2: |
61 | | - return None |
62 | | - |
63 | | - xq_out, xk_out = output_nodes |
64 | | - |
65 | | - return [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out] |
66 | | - |
67 | | - |
68 | | -def create_rotary_emb_custom_op( |
69 | | - ep: ExportedProgram, |
70 | | - graph_module: torch.fx.GraphModule, |
71 | | - match: InternalMatch, |
72 | | -): |
73 | | - io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match) |
74 | | - if io_nodes is None: |
75 | | - return |
76 | | - |
77 | | - assert len(io_nodes) == 6 |
78 | | - xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes |
79 | | - |
80 | | - # Create the custom op node |
81 | | - with graph_module.graph.inserting_before(xq_out): |
82 | | - rotary_emb_node = graph_module.graph.create_node( |
83 | | - "call_function", |
84 | | - exir_ops.edge.et_vk.apply_rotary_emb.default, |
85 | | - args=(xq, xk, freqs_cos, freqs_sin), |
86 | | - ) |
87 | | - |
88 | | - # The custom op returns a tuple (xq_out, xk_out) |
89 | | - # We need to extract the individual outputs |
90 | | - with graph_module.graph.inserting_after(rotary_emb_node): |
91 | | - getitem_0 = graph_module.graph.create_node( |
92 | | - "call_function", |
93 | | - operator.getitem, |
94 | | - args=(rotary_emb_node, 0), |
95 | | - ) |
96 | | - getitem_1 = graph_module.graph.create_node( |
97 | | - "call_function", |
98 | | - operator.getitem, |
99 | | - args=(rotary_emb_node, 1), |
100 | | - ) |
101 | | - |
102 | | - if hasattr(xq_out, "meta") and "val" in xq_out.meta: |
103 | | - getitem_0.meta["val"] = xq_out.meta["val"] |
104 | | - if hasattr(xk_out, "meta") and "val" in xk_out.meta: |
105 | | - getitem_1.meta["val"] = xk_out.meta["val"] |
106 | | - |
107 | | - xq_out.replace_all_uses_with(getitem_0) |
108 | | - xk_out.replace_all_uses_with(getitem_1) |
109 | | - |
110 | 14 |
|
111 | 15 | class FusePatternsPass(ExportPass): |
112 | 16 | def __init__(self, exported_program: ExportedProgram) -> None: |
|
0 commit comments