Commit 960ea70

Handle same tensor appearing multiple times in the cat input
Differential Revision: D88439174
Pull Request resolved: #16091
1 parent b765e99 commit 960ea70
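
For illustration, the pattern this commit handles arises when a model feeds the same intermediate tensor into a cat more than once. A minimal hypothetical example (the `DoubleCat` module below is not taken from the PR): after lowering, the resulting cat node lists one tensor twice in its inputs.

import torch

class DoubleCat(torch.nn.Module):
    # Hypothetical example: the same intermediate `y` feeds both cat inputs.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x + 1.0
        return torch.cat((y, y), dim=0)

print(DoubleCat()(torch.ones(3, 6)).shape)  # torch.Size([6, 6])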

File tree

2 files changed: +119 -0 lines changed


backends/cadence/aot/memory_constraints.py

Lines changed: 15 additions & 0 deletions
@@ -417,6 +417,10 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
             return not self.constraint.is_alias_of(source_info.source, node)
         return False
 
+    def has_relative_placement_constraint(self, node: torch.fx.Node) -> bool:
+        """Return if `node` already has any relative placement constraint."""
+        return self.constraint.get_relative_placement_source(node) is not None
+
     # Return true if the cat node performs concatenation along outermost dimension
     def is_cat_along_outermost_dim(
         self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
@@ -481,6 +485,17 @@ def is_removable_cat_op(
         if any(self.is_slice_view(arg) for arg in cat_tensors):
             return False
 
+        # If any of the tensors already has a relative placement constraint,
+        # we cannot add a new constraint for this cat without conflicting.
+        # This can happen when a tensor is used in multiple cat operations.
+        if any(self.has_relative_placement_constraint(arg) for arg in cat_tensors):
+            return False
+
+        # If the same tensor appears multiple times in the cat inputs,
+        # we cannot place it at multiple different offsets relative to the output.
+        if len(cat_tensors) != len(set(cat_tensors)):
+            return False
+
         # Many ops in HiFi require the input to be aligned to 8-byte boundary.
         # If the cat is not the graph's output, then ensure that the relative
         # offset of any concatenated non-placeholder tensor is a multiple of
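
For intuition on these guards: removing a cat here means pinning each input tensor at a fixed offset inside the output buffer, and a repeated input would need two different offsets at once (the existing-constraint guard catches the same conflict across two cats that share an input). A minimal standalone sketch of that reasoning, assuming concatenation along the outermost dimension; compute_relative_offsets is a hypothetical helper, not the pass's actual API:

import torch

def compute_relative_offsets(cat_inputs: list[torch.Tensor]) -> dict[int, int]:
    # Hypothetical helper: byte offset of each input inside the cat output.
    offsets: dict[int, int] = {}
    cursor = 0
    for t in cat_inputs:
        if id(t) in offsets and offsets[id(t)] != cursor:
            # The same tensor would have to live at two different offsets,
            # so the cat cannot be turned into a pure memory-planning nop.
            raise ValueError("duplicate cat input needs two placements")
        offsets[id(t)] = cursor
        cursor += t.numel() * t.element_size()
    return offsets

t = torch.ones(3, 6, dtype=torch.float32)
compute_relative_offsets([t, torch.zeros(3, 6)])  # fine: offsets 0 and 72
# compute_relative_offsets([t, t])                # would raise: t needs offsets 0 and 72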

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 104 additions & 0 deletions
@@ -947,6 +947,110 @@ def test_cat_then_cat(self) -> None:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)
 
+    def test_cat_with_duplicate_input_tensor(self) -> None:
+        """
+        Test that cat is NOT optimized when the same tensor appears multiple
+        times in the cat input list. This is because we cannot place the same
+        tensor at multiple different offsets relative to the output.
+        """
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        pre_created_output = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([6, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # Same tensor (add_x) appears twice in the cat inputs
+        cat = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([add_x, add_x],),
+            kwargs={"dim": 0, "out": pre_created_output},
+        )
+        builder.output([cat])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original)
+        graph_module.graph.eliminate_dead_code()
+
+        # Assert that cat op is NOT optimized away since the same tensor
+        # appears multiple times in the input list
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_cat_with_tensor_having_existing_constraint(self) -> None:
+        """
+        Test that the second cat is NOT optimized when a tensor already has a
+        relative placement constraint from a previous cat operation.
+        """
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(8, 8, dtype=torch.float32))
+        to_add = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([8, 8], 1.0),
+            kwargs={"dtype": torch.float32},
+        )
+        x1 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add),
+        )
+        x2 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x1, to_add),
+        )
+        x3 = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x2, to_add),
+        )
+        # First cat: cat(x1, x2) - this will give x1 and x2 relative placement constraints
+        pre_created_output1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([16, 8], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        cat1 = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([x1, x2],),
+            kwargs={"dim": 0, "out": pre_created_output1},
+        )
+        # Second cat: cat(x2, x3) - x2 already has a constraint from cat1,
+        # so this cat cannot be optimized
+        pre_created_output2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([16, 8], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        cat2 = builder.call_operator(
+            op=torch.ops.aten.cat.out,
+            args=([x2, x3],),
+            kwargs={"dim": 0, "out": pre_created_output2},
+        )
+        # Use both cat results to keep them alive
+        graph_output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(cat1, cat2),
+        )
+        builder.output([graph_output])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=False
+        )
+        graph_module.graph.eliminate_dead_code()
+
+        # The first cat should be optimized to _cat_nop, but the second cat
+        # cannot be optimized because x2 already has a relative placement constraint
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
     def test_view_for_unallocated_output(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32))
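
For readers unfamiliar with the test utilities: the assertions above count how many nodes with a given op target remain in the graph after memory planning. A plausible sketch of what a helper like count_node does (the actual utility lives in the Cadence test helpers and may differ):

import torch

def count_node(graph_module: torch.fx.GraphModule, target) -> int:
    # Count call_function nodes in the FX graph whose target matches `target`.
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and node.target == target
    )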
