diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
index e78cb5aca90..0a85da56d9b 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -1002,22 +1002,15 @@ def _annotate_cat(
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[list[list[Node]]]:
-    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
-    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
     annotated_partitions = []
-    for cat_partition in cat_partitions:
-        cat_node = cat_partition.output_nodes[0]
-        if _is_annotated([cat_node]):
+    for cat_node in gm.graph.nodes:
+        if cat_node.target != torch.ops.aten.cat.default:
             continue
-        if cat_node.target != torch.ops.aten.cat.default:
-            # TODO: change this to AnnotationException
-            raise Exception(  # noqa: TRY002
-                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
-                " please check if you are calling the correct capture API"
-            )
+        if _is_annotated([cat_node]):
+            continue

-        annotated_partitions.append(cat_partition.nodes)
+        annotated_partitions.append(cat_node.all_input_nodes)
         input_act_qspec = get_input_act_qspec(quantization_config)
         inputs = cat_node.args[0]
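
The diff replaces get_source_partitions-based matching with direct iteration over the exported FX graph: an aten.cat.default node is matched by target, skipped if already annotated, and its inputs are collected via all_input_nodes. Below is a minimal, self-contained sketch (not part of the diff) showing why this works; the toy CatModule and tensor shapes are illustrative assumptions, not taken from the PR.

    # Sketch: after torch.export, torch.cat appears in the graph as
    # torch.ops.aten.cat.default, and node.all_input_nodes lists the
    # tensors being concatenated -- the inputs the annotation pass records.
    import torch


    class CatModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=1)


    example_inputs = (torch.randn(2, 3), torch.randn(2, 5))
    gm = torch.export.export(CatModule(), example_inputs).module()

    for node in gm.graph.nodes:
        if node.target != torch.ops.aten.cat.default:
            continue
        # node.args[0] is the list of concatenated inputs; all_input_nodes
        # flattens every Node feeding this op, which the rewritten
        # _annotate_cat appends as the "partition" to annotate.
        print(node.name, [n.name for n in node.all_input_nodes])

A side effect of matching on the aten target directly is that the old "Expected cat node ... check if you are calling the correct capture API" exception becomes unnecessary: non-cat nodes are simply skipped by the target check.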