@@ -146,6 +146,61 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
146146 return checker
147147
148148
def _is_quantized_constant(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a constant-producing op with only quantized users.

    The node qualifies only when its target is ``full_like`` or one of
    ``ComputeConstantOpsAOTPass.targeted_ops`` and it has at least one user.
    Its users must then satisfy one of the following:

    - every user is an op in ``Q_OPS``, or
    - every user is a ``_to_dim_order_copy`` whose output dtype is neither
      floating point nor complex.

    Args:
        node (torch.fx.Node): The FX node to check.

    Returns:
        bool: True if the node matches the rules above, False otherwise.
    """
    constant_targets = (
        exir_ops.edge.aten.full_like.default,
        *ComputeConstantOpsAOTPass.targeted_ops,
    )
    if node.target not in constant_targets:
        return False

    consumers = tuple(node.users)
    if not consumers:
        # A constant nobody consumes is never treated as quantized.
        return False

    if all(consumer.target in Q_OPS for consumer in consumers):
        # The node feeds directly into only quantized ops.
        return True

    dim_order_copy = exir_ops.edge.dim_order_ops._to_dim_order_copy.default
    for consumer in consumers:
        if consumer.target != dim_order_copy:
            return False
        copy_dtype = get_first_fake_tensor(consumer).dtype
        if copy_dtype.is_complex or copy_dtype.is_floating_point:
            return False

    # Every user is a dim-order copy with an integer-like output dtype.
    return True
170+
171+
def is_quantized(node: torch.fx.Node) -> bool:
    """Decide whether ``node`` counts as quantized.

    The checks run in this order and the first match wins:
        1. The dtype of the node's first fake tensor is neither floating
           point nor complex.
        2. ``_is_quantized_constant(node)`` is true.
        3. The node's ``"custom"`` meta holds an entry under
           ``ArmAnnotationInfo.CUSTOM_META_KEY``; that entry's
           ``"quantized"`` value is returned as-is.

    Args:
        node (torch.fx.Node): The FX node to check.

    Returns:
        bool: True if the node is quantized, False otherwise.
    """
    dtype = get_first_fake_tensor(node).dtype
    integer_like = not (dtype.is_complex or dtype.is_floating_point)

    # An integer-like dtype or a constant feeding only quantized users is
    # enough; the explicit annotation is consulted only when both fail.
    if integer_like or _is_quantized_constant(node):
        return True

    custom_meta = node.meta.get("custom", {})
    annotation_key = ArmAnnotationInfo.CUSTOM_META_KEY
    if annotation_key not in custom_meta:
        return False

    return custom_meta[annotation_key]["quantized"]
202+
203+
149204def get_registered_tosa_support_checks (
150205 tosa_spec : TosaSpecification ,
151206) -> list [Type [SupportedTOSAOperatorCheck ]]:
@@ -194,9 +249,11 @@ def tosa_support_factory(
194249 ControlFlowOpSupported (exported_program , tosa_spec , reporter ),
195250 ]
196251
197- if tosa_spec .support_integer ():
252+ if tosa_spec .support_integer () and tosa_spec .support_float ():
253+ positive_checks .append (TOSAProINTFPSupportList ())
254+ elif tosa_spec .support_integer ():
198255 positive_checks .append (TOSAProINTSupportList ())
199- if tosa_spec .support_float ():
256+ elif tosa_spec .support_float ():
200257 positive_checks .append (TOSAProFPSupportList ())
201258 # TODO: Refactor to use TOSAProSupportLists + negative checks
202259 positive_checks += [
@@ -268,6 +325,27 @@ def is_node_supported(
268325 return node .op == "call_function" and node .target in TOSA_PRO_FP_SupportList
269326
270327
class TOSAProINTFPSupportList(OperatorSupportBase):
    """
    TOSA_PRO_INT_FP_SupportList:
        Ops supported in INT+FP profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOp.

    Each node is checked against TOSA_PRO_INT_SupportList when it is quantized
    (or is itself a Q/DQ op), and against TOSA_PRO_FP_SupportList otherwise.
    """

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Return True if ``node`` is a call_function whose target is in the matching support list.

        Args:
            submodules: Mapping of submodule names to modules; not read here.
            node: The FX node to check.

        Returns:
            bool: False for anything that is not a ``call_function`` node;
            otherwise whether ``node.target`` is in the selected list.
        """
        if node.op != "call_function":
            return False

        # Quantized nodes and Q/DQ ops use the integer list.
        use_int_list = is_quantized(node) or node.target in (*Q_OPS, *DQ_OPS)
        candidates = (
            TOSA_PRO_INT_SupportList if use_int_list else TOSA_PRO_FP_SupportList
        )
        return node.target in candidates
347+
348+
271349class CheckArmQuantized (OperatorSupportBase ):
272350 """
273351 Check if the node was marked as quantized in the Arm backend.
@@ -278,60 +356,14 @@ class CheckArmQuantized(OperatorSupportBase):
278356 def __init__ (self , reporter : WhyNoPartitionReporter ):
279357 self .reporter = reporter
280358
281- def _is_quantized (self , node : torch .fx .Node ) -> bool :
282- """Checks if the node is quantized.
283-
284- A node is considered quantized if at least one criteria is met:
285- - Its dtype is not floating point or complex => integer
286- - It is one of the special cases where the node has been created in to_edge, e.g.
287- .Scalar operations that have been promoted .Tensor operations
288- where the scalar is replaced by a full op.
289- - It has been marked as quantized in the ArmAnnotationInfo custom meta.
290-
291- Args:
292- node (torch.fx.Node): The FX node to check.
293-
294- Returns:
295- bool: True if the node is quantized, False otherwise.
296- """
297- node_dtype = get_first_fake_tensor (node ).dtype
298- if not node_dtype .is_complex and not node_dtype .is_floating_point :
299- return True
300- if node .target in (
301- exir_ops .edge .aten .full_like .default ,
302- * ComputeConstantOpsAOTPass .targeted_ops ,
303- ):
304- # Special cases where nodes have been created in to_edge, e.g.
305- # .Scalar operations that have been promoted .Tensor operations
306- # where the scalar is replaced by a full op.
307- if all (user .target in Q_OPS for user in node .users ):
308- return True
309- for user in node .users :
310- if (
311- user .target
312- == exir_ops .edge .dim_order_ops ._to_dim_order_copy .default
313- ):
314- dim_order_dtype = get_first_fake_tensor (user ).dtype
315- if dim_order_dtype .is_complex or dim_order_dtype .is_floating_point :
316- return False
317- else :
318- return False
319- return True
320- return (
321- ArmAnnotationInfo .CUSTOM_META_KEY in node .meta .get ("custom" , {})
322- and ArmAnnotationInfo (
323- node .meta ["custom" ][ArmAnnotationInfo .CUSTOM_META_KEY ]
324- ).quantized
325- )
326-
327359 def is_node_supported (
328360 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
329361 ) -> bool :
330362
331363 if node .target in (* DQ_OPS , * Q_OPS ):
332364 return True
333365
334- if not self . _is_quantized (node ):
366+ if not is_quantized (node ):
335367 self .reporter .report_reject (
336368 node , "Node was not marked as quantized in the Arm backend."
337369 )
0 commit comments