55# LICENSE file in the root directory of this source tree.
66
77import logging
8+
89import torch
10+ from executorch .backends .xnnpack .utils .quant_utils import is_dequant , is_quant
911from executorch .exir .dialects ._ops import ops as exir_ops
1012
1113from executorch .exir .pass_base import ExportPass , PassResult
12- from executorch .backends .xnnpack .utils .quant_utils import (
13- is_dequant ,
14- is_quant ,
15- )
1614
1715logger = logging .getLogger (__name__ )
1816logger .setLevel (logging .WARNING )
1917
18+
2019class DecomposeConcatenate (ExportPass ):
2120 """
2221 XNNPACK's Concatenate operation only supports concatenation for <= 5 tensors
@@ -25,37 +24,40 @@ class DecomposeConcatenate(ExportPass):
2524
2625 Example:
2726 Before Pass:
28- cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
29-
27+ cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
28+
3029 After Pass:
31- cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
32- cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
30+ cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
31+ cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
3332 """
3433
3534 def call (self , graph_module : torch .fx .GraphModule ):
3635 gm = graph_module
3736 for node in gm .graph .nodes :
38- if (node .op == "call_function"
39- and node .target .__name__ == "aten.cat.default" ):
37+ if (
38+ node .op == "call_function"
39+ and node .target .__name__ == "aten.cat.default"
40+ ):
4041 concat_args = node .args
4142 nodes_to_concat = node .args [0 ]
4243 if len (nodes_to_concat ) <= 5 :
4344 continue
44-
45- is_quantized = (all (is_dequant (node ) for node in nodes_to_concat )
46- and all (is_quant (node ) for node in node .users .keys ()))
45+
46+ is_quantized = all (
47+ is_dequant (node ) for node in nodes_to_concat
48+ ) and all (is_quant (node ) for node in node .users .keys ())
4749
4850 # replace the cat args with the same args but only with the first 5 nodes
4951 new_concat_args = (nodes_to_concat [:5 ],) + concat_args [1 :]
50- node .args = new_concat_args
52+ node .args = new_concat_args
5153
5254 remainder_nodes_to_concat = nodes_to_concat [5 :]
5355 with gm .graph .inserting_after (node ):
5456 logger .debug (f"Decomposing cat node { node } " )
5557 remainder_concat_node = gm .graph .create_node (
5658 "call_function" ,
5759 target = exir_ops .edge .aten .cat .default ,
58- args = ([],), # we will replace this remainder_nodes later
60+ args = ([],), # we will replace this remainder_nodes later
5961 kwargs = node .kwargs ,
6062 )
6163 node .replace_all_uses_with (remainder_concat_node )
@@ -64,11 +66,13 @@ def call(self, graph_module: torch.fx.GraphModule):
6466 # concat node
6567 q_params = nodes_to_concat [0 ].args [1 :]
6668 q_kwargs = nodes_to_concat [0 ].kwargs
67- # Quantizer enforces all the inputs and output to a concat node must share
69+ # Quantizer enforces all the inputs and output to a concat node must share
6870 # the same qparams, this means the newly inserted q/dq pair must share the
6971 # same qparams as the first quantized input in the concat node.
7072 with gm .graph .inserting_after (node ):
71- logger .debug (f"Inserting Q/DQ pair for new cat node { remainder_concat_node } " )
73+ logger .debug (
74+ f"Inserting Q/DQ pair for new cat node { remainder_concat_node } "
75+ )
7276 q_node = gm .graph .create_node (
7377 "call_function" ,
7478 target = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
@@ -82,10 +86,14 @@ def call(self, graph_module: torch.fx.GraphModule):
8286 args = (q_node ,) + q_params ,
8387 kwargs = q_kwargs ,
8488 )
85- remainder_concat_node .args = ([dq_node ] + remainder_nodes_to_concat ,) + node .args [1 :]
89+ remainder_concat_node .args = (
90+ [dq_node ] + remainder_nodes_to_concat ,
91+ ) + node .args [1 :]
8692 else :
87- remainder_concat_node .args = ([node ] + remainder_nodes_to_concat ,) + node .args [1 :]
88-
93+ remainder_concat_node .args = (
94+ [node ] + remainder_nodes_to_concat ,
95+ ) + node .args [1 :]
96+
8997 gm .recompile ()
9098 new_gm = super ().call (gm ).graph_module
9199 return PassResult (new_gm , True )
0 commit comments