1818 FuseQuantizedActivationPass ,
1919)
2020from executorch .backends .arm ._passes .insert_table_ops import TableOps
21+ from executorch .backends .arm .common .annotation_meta import ArmAnnotationInfo
2122from executorch .backends .arm .constants import DQ_OPS , MAX_RANK , Q_OPS
2223from executorch .backends .arm .operator_support .ethos_u55_support import (
2324 EthosU55CastCheck ,
@@ -134,6 +135,7 @@ def tosa_support_factory(
134135 ]
135136
136137 if not tosa_spec .support_float ():
138+ negative_checks .append (CheckArmQuantized (reporter ))
137139 negative_checks .append (CheckProperQuantization (reporter ))
138140 if tosa_spec .is_U55_subset :
139141 negative_checks .append (EthosU55NotSupported (reporter ))
@@ -161,7 +163,6 @@ class TOSAProINTSupportList(OperatorSupportBase):
161163 def is_node_supported (
162164 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
163165 ) -> bool :
164-
165166 return node .op == "call_function" and node .target in TOSA_PRO_INT_SupportList
166167
167168
@@ -174,10 +175,80 @@ class TOSAProFPSupportList(OperatorSupportBase):
174175 def is_node_supported (
175176 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
176177 ) -> bool :
177-
178178 return node .op == "call_function" and node .target in TOSA_PRO_FP_SupportList
179179
180180
class CheckArmQuantized(OperatorSupportBase):
    """
    Check if the node was marked as quantized in the Arm backend.
    This is used to ensure that nodes that were quantized in the Arm backend
    are only partitioned if they are supported by the TOSA backend.
    """

    def __init__(self, reporter: WhyNoPartitionReporter):
        # Reporter used to record why a node was rejected from partitioning.
        self.reporter = reporter

    def _is_quantized(self, node: torch.fx.Node) -> bool:
        """Checks if the node is quantized.

        A node is considered quantized if at least one criteria is met:
        - Its dtype is not floating point or complex => integer
        - It is one of the special cases where the node has been created in to_edge, e.g.
          .Scalar operations that have been promoted .Tensor operations
          where the scalar is replaced by a full op.
        - It has been marked as quantized in the ArmAnnotationInfo custom meta.

        Args:
            node (torch.fx.Node): The FX node to check.

        Returns:
            bool: True if the node is quantized, False otherwise.
        """
        dtype = get_first_fake_tensor(node).dtype
        # An integer output (neither float nor complex) is quantized by definition.
        if not (dtype.is_complex or dtype.is_floating_point):
            return True

        # Special cases where nodes have been created in to_edge, e.g.
        # .Scalar operations that have been promoted .Tensor operations
        # where the scalar is replaced by a full op.
        promoted_targets = (
            exir_ops.edge.aten.full_like.default,
            *ComputeConstantOpsAOT.targeted_ops,
        )
        if node.target in promoted_targets:
            users = node.users
            # All users being quantize ops means the float output is immediately
            # quantized away.
            if all(user.target in Q_OPS for user in users):
                return True
            dim_order_target = exir_ops.edge.dim_order_ops._to_dim_order_copy.default
            for user in users:
                if user.target != dim_order_target:
                    return False
                user_dtype = get_first_fake_tensor(user).dtype
                if user_dtype.is_complex or user_dtype.is_floating_point:
                    return False
            return True

        # Fall back to the explicit Arm annotation recorded during quantization.
        custom_meta = node.meta.get("custom", {})
        if ArmAnnotationInfo.CUSTOM_META_KEY not in custom_meta:
            return False
        return custom_meta[ArmAnnotationInfo.CUSTOM_META_KEY].quantized

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Reject call_function nodes that were not quantized by the Arm backend."""
        if node.op != "call_function":
            return False

        # Quantize/dequantize ops are always allowed through.
        if node.target in (*DQ_OPS, *Q_OPS):
            return True

        if self._is_quantized(node):
            return True
        self.reporter.report_reject(
            node, "Node was not marked as quantized in the Arm backend."
        )
        return False
250+
251+
181252class CheckProperQuantization (OperatorSupportBase ):
182253 """
183254 For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
@@ -350,7 +421,6 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
350421 def is_node_supported (
351422 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
352423 ) -> bool :
353-
354424 vals = node .meta ["val" ]
355425 tensor_list = vals if isinstance (vals , (list , tuple )) else [vals ]
356426
@@ -416,7 +486,6 @@ def is_node_supported(
416486
417487
418488class CheckFloat64Inputs (OperatorSupportBase ):
419-
420489 def __init__ (
421490 self , exported_program : ExportedProgram , reporter : WhyNoPartitionReporter
422491 ):
@@ -426,7 +495,6 @@ def __init__(
426495 def is_node_supported (
427496 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
428497 ) -> bool :
429-
430498 for input_node in node .all_input_nodes :
431499 tensor = get_first_fake_tensor (input_node )
432500 if tensor .dtype == torch .float64 :
0 commit comments