77
88import itertools
99import operator
10- from typing import List
10+ from typing import cast , List
1111
1212import torch
1313from executorch .backends .arm ._passes .arm_pass_utils import create_node
1414
15- from executorch .backends .arm .tosa_quant_utils import dq_op , q_op , QuantArgs
15+ from executorch .backends .arm .tosa_quant_utils import dq_ops , q_ops
1616from executorch .exir .dialects ._ops import ops as exir_ops
17+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
1718from executorch .exir .pass_base import ExportPass , PassResult
1819from torch .fx import GraphModule
1920from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
@@ -61,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6162 }
6263 for partition in matmul_partitions :
6364 quantized_input = all (
64- input_node .target == dq_op for input_node in partition .input_nodes
65+ input_node .target in dq_ops for input_node in partition .input_nodes
6566 )
6667 matmul_node = [
6768 node for node in partition .nodes if node .target in matmul_targets
@@ -74,17 +75,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7475 input_node = self ._match_partition_to_node (
7576 node , partition .input_nodes
7677 )
77- input_node_qargs = QuantArgs .from_operator (
78- input_node .target , input_node .args
79- )
8078 # Insert new dq-node just before the mm/bmm with input_node's qparams
8179 with graph_module .graph .inserting_before (matmul_node ):
8280 # Create new dq-node before matmul
8381 dq_node = create_node (
8482 graph = graph_module .graph ,
85- op_target = dq_op ,
83+ op_target = cast ( EdgeOpOverload , input_node . target ), # type: ignore[arg-type]
8684 )
87- dq_node .args = (node , * input_node_qargs )
85+ dq_node .args = (node , * input_node . args [ 1 :] )
8886 matmul_node .replace_input_with (node , dq_node )
8987
9088 for partition_input in partition .input_nodes :
@@ -95,19 +93,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
9593 graph_module .graph .erase_node (partition_input )
9694
9795 partition_output = list (partition .output_nodes [0 ].users )[0 ]
98- quantized_output = partition_output .target == q_op
96+ quantized_output = partition_output .target in q_ops
9997 if quantized_output :
100- output_node_qargs = QuantArgs .from_operator (
101- partition_output .target , partition_output .args
102- )
10398 with graph_module .graph .inserting_after (matmul_node ):
10499 # Create q-node after matmul
105100 q_node = create_node (
106101 graph = graph_module .graph ,
107- op_target = q_op ,
102+ op_target = cast ( EdgeOpOverload , partition_output . target ), # type: ignore[arg-type]
108103 )
109104 matmul_node .replace_all_uses_with (q_node )
110- q_node .args = (matmul_node , * output_node_qargs )
105+ q_node .args = (matmul_node , * partition_output . args [ 1 :] )
111106 # Remove partition output q-node
112107 partition_output .replace_all_uses_with (
113108 partition_output .all_input_nodes [0 ]
0 commit comments