Skip to content

Commit d37143a

Browse files
committed
[aoti-et] Tag unused lifted constant placeholders
For lifted constants, if they are not being used, they won't be tagged by `tag_constant_data()` API. We have to manually tag them in cuda partitioner.
1 parent 3485495 commit d37143a

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

backends/cuda/cuda_partitioner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
PartitionResult,
1717
)
1818
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
19+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
1920
from torch.export.exported_program import ExportedProgram
2021

2122

@@ -56,6 +57,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5657
tag_constant_data(exported_program)
5758
tag_mutated_buffer(exported_program)
5859

60+
# Tag constant placeholders that have no users
61+
# tag_constant_data only tags constants that have users with delegation_tag
62+
# but we need to tag all constants for this partition
63+
for node in exported_program.graph.nodes:
64+
if node.op == "placeholder" and (
65+
is_param(exported_program, node)
66+
or is_buffer(exported_program, node)
67+
or is_lifted_tensor_constant(exported_program, node)
68+
):
69+
if "delegation_tag" not in node.meta:
70+
node.meta["delegation_tag"] = tag
71+
5972
return PartitionResult(
6073
tagged_exported_program=exported_program, partition_tags=partition_tags
6174
)

backends/cuda/tests/test_cuda_partitioner.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
139139
fully_partitioned,
140140
"Graph should be fully partitioned with all operators having the same tag",
141141
)
142+
143+
def test_unused_constant_tagging(self):
144+
"""
145+
Test that constant nodes without users are properly tagged with delegation_tag.
146+
147+
When a graph contains constants (parameters, buffers, or lifted tensor constants)
148+
that are not used by any operations, the CUDA partitioner should still tag them
149+
with the delegation_tag. This ensures all constant data is properly handled during
150+
delegation, even if they have no users in the graph.
151+
"""
152+
153+
class ModuleWithUnusedConst(torch.nn.Module):
154+
def __init__(self):
155+
super().__init__()
156+
# Register a buffer that won't be used in forward
157+
self.register_buffer("unused_buffer", torch.randn(10, 10))
158+
# Also register a used parameter
159+
self.weight = torch.nn.Parameter(torch.randn(5, 5))
160+
161+
def forward(self, x: torch.Tensor) -> torch.Tensor:
162+
# Only use the weight parameter, not the unused_buffer
163+
return x + self.weight
164+
165+
module = ModuleWithUnusedConst()
166+
inputs = (torch.randn(5, 5),)
167+
168+
# Get partition result
169+
partition_result = self._get_partition_result(module, inputs)
170+
171+
# Find all placeholder nodes (these represent constants, parameters, buffers, and inputs)
172+
constant_placeholders = []
173+
input_placeholders = []
174+
175+
for node in partition_result.tagged_exported_program.graph.nodes:
176+
if node.op == "placeholder":
177+
# Check if this is a constant (param, buffer, or lifted tensor constant)
178+
from torch._export.utils import (
179+
is_buffer,
180+
is_lifted_tensor_constant,
181+
is_param,
182+
)
183+
184+
is_constant = (
185+
is_param(partition_result.tagged_exported_program, node)
186+
or is_buffer(partition_result.tagged_exported_program, node)
187+
or is_lifted_tensor_constant(
188+
partition_result.tagged_exported_program, node
189+
)
190+
)
191+
192+
if is_constant:
193+
constant_placeholders.append(node)
194+
else:
195+
input_placeholders.append(node)
196+
197+
# Verify we have constant placeholders
198+
self.assertGreater(
199+
len(constant_placeholders),
200+
0,
201+
"Expected to find constant placeholders in the graph",
202+
)
203+
204+
# Check that all constant placeholders are tagged, including unused ones
205+
untagged_constants = []
206+
for node in constant_placeholders:
207+
if "delegation_tag" not in node.meta:
208+
untagged_constants.append(node.name)
209+
210+
self.assertEqual(
211+
len(untagged_constants),
212+
0,
213+
f"All constant placeholders should be tagged. Found untagged constants: {untagged_constants}",
214+
)
215+
216+
# Verify all tagged constants have the expected tag
217+
expected_tag = "tag0"
218+
for node in constant_placeholders:
219+
actual_tag = node.meta.get("delegation_tag")
220+
self.assertEqual(
221+
actual_tag,
222+
expected_tag,
223+
f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'",
224+
)

0 commit comments

Comments
 (0)