#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+ from executorch.backends.qualcomm._passes.utils import find_patterns
import torch

- from executorch.backends.qualcomm.builders.node_visitor import dq_ops
- from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
- from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

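+ # Node predicates used to describe the decomposed RMSNorm graph below. Each
+ # predicate accepts both the edge-dialect and the core-aten form of the op,
+ # so the same patterns can be matched before or after to_edge.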
+ def _is_node(node): return isinstance(node, torch.fx.Node)
+ def _is_call(node): return _is_node(node) and node.op == "call_function"
+ def _is_placeholder(node): return _is_node(node) and node.op == "placeholder"
+ def _is_get_attr(node): return _is_node(node) and node.op == "get_attr"
+ def _is_add(node): return _is_call(node) and node.target in [exir_ops.edge.aten.add.Tensor, torch.ops.aten.add.Tensor]
+ def _is_mean(node): return _is_call(node) and node.target in [exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim]
+ def _is_mul(node): return _is_call(node) and node.target in [exir_ops.edge.aten.mul.Tensor, torch.ops.aten.mul.Tensor]
+ def _is_pow(node): return _is_call(node) and node.target in [exir_ops.edge.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar]
+ def _is_rsqrt(node): return _is_call(node) and node.target in [exir_ops.edge.aten.rsqrt.default, torch.ops.aten.rsqrt.default]

class RecomposeRmsNorm(ExportPass):
    """
    Merge decomposed operators back to one super node.
-     TODO: After replacing export_to_edge with to_edge_transform_and_lowering
-     in examples/models/llama/export_llama_lib.py, this pass can be removed
    """

-     def __init__(self, edge_program: torch.export.ExportedProgram):
+     def __init__(self, quantization_capture=False):
        super(RecomposeRmsNorm, self).__init__()
-         self.edge_program = edge_program
-
-     def _get_eps_node(self, nodes):
-         # eps: one of inputs of add node
-         add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0]
-         for a in add_node.args:
-             if isinstance(a, float) or a.op != "call_function":
-                 return a
-
-     def _get_gamma_node(self, output_node):
-         # gamma: one of inputs of output node
-         for a in output_node.args:
-             if a.op != "call_function" or a.target in dq_ops:
-                 return a
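+         # During quantization capture the graph still holds core-aten ops;
+         # after to_edge it holds edge-dialect ops, so pick targets accordingly.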
+         self.rms_norm_target = exir_ops.edge.aten.rms_norm.default
+         self.skip_targets = [exir_ops.edge.aten.to.dtype]
+         if quantization_capture:
+             self.rms_norm_target = torch.ops.aten.rms_norm.default
+             self.skip_targets = [torch.ops.aten.to.dtype]
+
+     def _get_input_node(self, node):
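+         # Walk upward past no-op dtype casts (skip_targets) to reach the real
+         # producer of this node's first input.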
+         input_node = node.args[0]
+         while input_node.target in self.skip_targets:
+             input_node = input_node.args[0]
+         return input_node

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
-         partitions = get_source_partitions(
-             graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default]
-         )
-         for _, src_partitions in partitions.items():
-             for src_partition in src_partitions:
-                 input_len = len(src_partition.input_nodes)
-                 if input_len == 1:
-                     input_node = src_partition.input_nodes[0]
-                 elif input_len == 2:
-                     inp_0, inp_1 = src_partition.input_nodes
-                     input_node = inp_0 if len(inp_0.users) == 2 else inp_1
-                 else:
-                     raise RuntimeError(
-                         f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs"
-                     )

-                 output_node = src_partition.output_nodes[0]
-                 eps = self._get_eps_node(src_partition.nodes)
-                 if isinstance(eps, torch.fx.Node) and is_parameter(
-                     eps, self.edge_program
-                 ):
-                     eps = get_parameter(eps, self.edge_program).item()
-                 gamma_node = self._get_gamma_node(output_node)
+         # Mathematically equivalent decompositions of Root Mean Square normalization
+         patterns = [
+             # transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
+             [_is_mul, "*", _is_mul, _is_rsqrt, _is_add, _is_mean, _is_pow],
+             # executorch.examples.models.llama.norm.RMSNorm
+             [_is_mul, "*", _is_mul, _is_rsqrt, _is_add, _is_mean, _is_mul],
+         ]
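+         # Each pattern walks the graph from output back to input:
+         #     y = weight * (x * rsqrt(mean(x ** 2, dim) + eps))
+         # Index 0 is the final mul with the weight, index 4 the add of eps, and
+         # index 6 the squaring of the input (pow, or x * x in the llama RMSNorm).
+         # The "*" entry presumably acts as a wildcard for an intervening node,
+         # e.g. the dtype cast in Qwen2RMSNorm's forward.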
+
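+         # Every decomposed RMSNorm ends in an elementwise mul, so only mul
+         # nodes are considered as potential roots of a match.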
+         for node in graph.nodes:
+             if not _is_mul(node):
+                 continue
+
+             rms_norm_patterns = [pattern for pattern in find_patterns(node, patterns) if pattern is not None]
+
+             if len(rms_norm_patterns) > 0:
+                 # Use the first matched pattern
+                 rms_norm_pattern = rms_norm_patterns[0][0]
+                 last_mul_node = rms_norm_pattern[0]
+                 gamma_node = None
+                 # The weight (gamma) should be a constant input
+                 for arg in last_mul_node.args:
+                     if _is_get_attr(arg) or _is_placeholder(arg):
+                         gamma_node = arg
+                 if gamma_node is None:
+                     continue
+
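+                 # Pattern index 4 is the add node; its second argument is eps.
+                 # After to_edge the scalar may have been lifted to a constant
+                 # tensor node, in which case the value is read from its metadata.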
+                 eps = rms_norm_pattern[4].args[1]
+                 if isinstance(eps, torch.fx.Node):
+                     eps = eps.meta["val"].constant.item()
+                 input_node = self._get_input_node(rms_norm_pattern[6])

-                 with graph.inserting_before(output_node):
+                 with graph.inserting_before(last_mul_node):
                    # args schema
                    # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
                    rms_node = graph.create_node(
                        "call_function",
-                         exir_ops.edge.aten.rms_norm.default,
+                         self.rms_norm_target,
                        (
                            input_node,
                            list(gamma_node.meta["val"].shape),
                            gamma_node,
                            eps,
                        ),
                    )
-                 users = output_node.users.copy()
+                 users = last_mul_node.users.copy()
                for user in users:
-                     user.replace_input_with(output_node, rms_node)
+                     user.replace_input_with(last_mul_node, rms_node)
                # copy metadata
-                 rms_node.meta = output_node.meta
+                 rms_node.meta = last_mul_node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
+