Skip to content

Commit ee15d55

Browse files
mcr229 authored and facebook-github-bot committed
Don't use source_partitions for quantizing cat
Summary: Source partitions aren't always reliable. There was a model that I recently saw where cat nodes weren't being quantized (likely because they were derived from a different source partition). Moving towards checking all cat nodes should be more reliable. Differential Revision: D78138409
1 parent edf25c4 commit ee15d55

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,22 +1002,16 @@ def _annotate_cat(
10021002
quantization_config: Optional[QuantizationConfig],
10031003
filter_fn: Optional[Callable[[Node], bool]] = None,
10041004
) -> Optional[list[list[Node]]]:
1005-
cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
1006-
cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
10071005
annotated_partitions = []
1008-
for cat_partition in cat_partitions:
1009-
cat_node = cat_partition.output_nodes[0]
1006+
for cat_node in gm.graph.nodes:
1007+
if cat_node.target != torch.ops.aten.cat.default:
1008+
continue
1009+
10101010
if _is_annotated([cat_node]):
10111011
continue
10121012

1013-
if cat_node.target != torch.ops.aten.cat.default:
1014-
# TODO: change this to AnnotationException
1015-
raise Exception( # noqa: TRY002
1016-
f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
1017-
" please check if you are calling the correct capture API"
1018-
)
10191013

1020-
annotated_partitions.append(cat_partition.nodes)
1014+
annotated_partitions.append(cat_node.all_input_nodes)
10211015

10221016
input_act_qspec = get_input_act_qspec(quantization_config)
10231017
inputs = cat_node.args[0]

0 commit comments

Comments
 (0)