55
66# pyre-unsafe
77
8+ import itertools
89import operator
10+ import typing
911from typing import final , Optional , Sequence , Type
1012
13+ import torch
14+
1115import torch .fx as fx
16+ from executorch .backends .arm ._passes .arm_pass_utils import get_first_fake_tensor
17+ from executorch .backends .arm ._passes .fuse_quantized_activation_pass import (
18+ FuseQuantizedActivationPass ,
19+ )
1220from executorch .backends .arm .tosa_specification import TosaSpecification
1321from executorch .exir .dialects ._ops import ops as exir_ops
1422from torch .fx .passes .operator_support import any_chain , chain , OperatorSupportBase
23+ from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
1524
1625
1726class SupportedTOSAOperatorCheck (OperatorSupportBase ):
@@ -27,7 +36,9 @@ def __init__(self, tosa_spec: TosaSpecification):
2736 targets : list [str ] = []
2837
@final
def is_node_supported(
    self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
    """
    Return whether this check accepts ``node``.

    A node whose target is not listed in ``self.targets`` is rejected
    outright. Otherwise the verdict is whatever
    ``self.is_node_tosa_supported(node, self.tosa_spec)`` returns.
    ``submodules`` is not used by this method.
    """
    # Short-circuiting keeps the TOSA-specific check from running for targets
    # this check does not handle.
    return node.target in self.targets and self.is_node_tosa_supported(
        node, self.tosa_spec
    )
@@ -75,6 +86,10 @@ def tosa_support_factory(
7586 tosa_spec : TosaSpecification ,
7687 additional_checks : Optional [Sequence [OperatorSupportBase ]] = None ,
7788) -> OperatorSupportBase :
89+ negative_checks : list [OperatorSupportBase ] = []
90+ if not tosa_spec .support_float ():
91+ negative_checks .append (NeedsDecompositionCheck ())
92+ negative_checks .append (CheckProperQuantization ())
7893 return chain (
7994 any_chain (
8095 BaseTOSASupportList (),
@@ -83,13 +98,16 @@ def tosa_support_factory(
8398 for check in get_registered_tosa_support_checks (tosa_spec )
8499 ),
85100 ),
101+ * negative_checks ,
86102 * additional_checks if additional_checks else [],
87103 )
88104
89105
90106class BaseTOSASupportList (OperatorSupportBase ):
91107
92- def is_node_supported (self , submodules , node : fx .Node ) -> bool :
108+ def is_node_supported (
109+ self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
110+ ) -> bool :
93111 supported = node .op == "call_function" and node .target in [
94112 exir_ops .edge .aten .abs .default ,
95113 exir_ops .edge .aten .add .Tensor ,
@@ -150,3 +168,154 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
150168 ]
151169
152170 return supported
171+
172+
class NeedsDecompositionCheck(OperatorSupportBase):
    """
    Reject operators that have to be decomposed before quantization.

    Decomposing these operators first is what gives each resulting operator its
    own surrounding pair of q-dq-nodes and its own quantization parameters, so
    a node that still carries one of the targeted (undecomposed) operators is
    reported as unsupported.
    """

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Return False for ``call_function`` nodes whose target needs decomposition.

        Nodes with any other ``op`` are always accepted. ``mean.dim`` is
        accepted only when its second positional argument equals ``[-1, -2]``.
        ``submodules`` is not used.
        """
        if node.op != "call_function":
            return True

        target = node.target
        if target == exir_ops.edge.aten.mean.dim:
            # mean.dim is decided entirely here, so it is not repeated in the
            # tuple below.
            return node.args[1] == [-1, -2]

        decomposed_targets = (
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.var.correction,
            exir_ops.edge.aten.var.dim,
        )
        return target not in decomposed_targets
200+
201+
class CheckProperQuantization(OperatorSupportBase):
    """
    For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
    and dequantize nodes surrounds the node. This is necessary for table operators and operators that need to rescale
    activations.
    """

    dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default

    def _is_matmul_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Find the matmul source partition containing this node and check that all its inputs and outputs are quantized.

        Every fx.GraphModule in ``submodules`` is searched for ``torch.matmul``
        source partitions. The first module that has a partition containing
        ``node`` decides the result: True only if every partition input node is
        a dequantize op and every user of the partition's first output node is
        a quantize op.

        Returns False when no module has a partition containing ``node``,
        which includes an empty ``submodules`` mapping.
        """
        for graph_module in submodules.values():
            # Only GraphModules carry an fx graph to search; skipping other
            # modules avoids an AttributeError on `.graph`.
            if not isinstance(graph_module, fx.GraphModule):
                continue

            partitions_by_source = get_source_partitions(
                graph_module.graph,
                [
                    torch.matmul,
                ],
                None,
            )
            containing = [
                partition
                for partition in itertools.chain.from_iterable(
                    partitions_by_source.values()
                )
                if node in partition.nodes
            ]
            if not containing:
                # The node lives in another module's graph; keep looking
                # instead of rejecting it because this graph does not own it.
                continue

            # Use the last matching partition if there are several.
            matched_partition = containing[-1]

            input_quantized = all(
                input_node.target == self.dq_op
                for input_node in matched_partition.input_nodes
            )
            if not input_quantized:
                return False

            return all(
                output_node_user.target == self.q_op
                for output_node_user in matched_partition.output_nodes[0].users
            )

        # No graph contained a matmul partition with this node.
        return False

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Return whether ``node`` is quantized the way its operator requires.

        Nodes whose target is not one of the targeted operators are accepted
        without further checks. For targeted nodes:

        * ``bmm``/``mm`` whose last ``source_fn_stack`` entry is
          ``torch.matmul`` are decided by ``_is_matmul_node_supported``.
        * ``max_pool2d_with_indices`` counts as output-quantized when every
          user is a ``getitem`` whose own users are all quantize ops.
        * A node that ``FuseQuantizedActivationPass`` treats as a fuseable
          input counts as output-quantized when all of its users are fuseable
          quantized activations.
        * A fuseable quantized activation counts as input-quantized when its
          first input node is a fuseable input.

        Otherwise every input node must be a dequantize op or have a
        non-floating-point dtype, and every user must be a quantize op or have
        a non-floating-point dtype.
        """
        output_quantized = False
        input_quantized = False
        if node.target not in (
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
        ):
            return True
        elif node.target in (
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.mm.default,
        ):
            source_fn_stack: typing.Sequence[typing.Any] = node.meta.get(
                "source_fn_stack", []
            )
            if source_fn_stack and source_fn_stack[-1][1] in (torch.matmul,):
                return self._is_matmul_node_supported(submodules, node)
            # bmm/mm that did not come from torch.matmul fall through to the
            # generic input/output checks below.

        elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):
            # The users are getitem nodes, so look one level further for the
            # quantize ops.
            output_quantized = all(
                user.target == operator.getitem
                and all(user_user.target == self.q_op for user_user in user.users)
                for user in node.users
            )
        elif FuseQuantizedActivationPass._is_fuseable_input(node):
            output_quantized = all(
                FuseQuantizedActivationPass._is_fuseable_quantized_activation(user)
                for user in node.users
            )
        elif FuseQuantizedActivationPass._is_fuseable_quantized_activation(node):
            input_node = node.all_input_nodes[0]
            input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)

        input_quantized = input_quantized or all(
            (input_node.target == self.dq_op)
            or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
            for input_node in node.all_input_nodes
        )

        if not input_quantized:
            return False

        output_quantized = output_quantized or all(
            (output_node.target == self.q_op)
            or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
            for output_node in node.users
        )

        return output_quantized
0 commit comments