@@ -1002,22 +1002,15 @@ def _annotate_cat(
10021002 quantization_config : Optional [QuantizationConfig ],
10031003 filter_fn : Optional [Callable [[Node ], bool ]] = None ,
10041004) -> Optional [list [list [Node ]]]:
1005- cat_partitions = get_source_partitions (gm .graph , [torch .cat ], filter_fn )
1006- cat_partitions = list (itertools .chain .from_iterable (cat_partitions .values ()))
10071005 annotated_partitions = []
1008- for cat_partition in cat_partitions :
1009- cat_node = cat_partition .output_nodes [0 ]
1010- if _is_annotated ([cat_node ]):
1006+ for cat_node in gm .graph .nodes :
1007+ if cat_node .target != torch .ops .aten .cat .default :
10111008 continue
10121009
1013- if cat_node .target != torch .ops .aten .cat .default :
1014- # TODO: change this to AnnotationException
1015- raise Exception ( # noqa: TRY002
1016- f"Expected cat node: torch.ops.aten.cat.default, but found { cat_node .target } "
1017- " please check if you are calling the correct capture API"
1018- )
1010+ if _is_annotated ([cat_node ]):
1011+ continue
10191012
1020- annotated_partitions .append (cat_partition . nodes )
1013+ annotated_partitions .append (cat_node . all_input_nodes )
10211014
10221015 input_act_qspec = get_input_act_qspec (quantization_config )
10231016 inputs = cat_node .args [0 ]
0 commit comments