@@ -139,3 +139,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
139139 fully_partitioned ,
140140 "Graph should be fully partitioned with all operators having the same tag" ,
141141 )
142+
143+ def test_unused_constant_tagging (self ):
144+ """
145+ Test that constant nodes without users are properly tagged with delegation_tag.
146+
147+ When a graph contains constants (parameters, buffers, or lifted tensor constants)
148+ that are not used by any operations, the CUDA partitioner should still tag them
149+ with the delegation_tag. This ensures all constant data is properly handled during
150+ delegation, even if they have no users in the graph.
151+ """
152+
153+ class ModuleWithUnusedConst (torch .nn .Module ):
154+ def __init__ (self ):
155+ super ().__init__ ()
156+ # Register a buffer that won't be used in forward
157+ self .register_buffer ("unused_buffer" , torch .randn (10 , 10 ))
158+ # Also register a used parameter
159+ self .weight = torch .nn .Parameter (torch .randn (5 , 5 ))
160+
161+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
162+ # Only use the weight parameter, not the unused_buffer
163+ return x + self .weight
164+
165+ module = ModuleWithUnusedConst ()
166+ inputs = (torch .randn (5 , 5 ),)
167+
168+ # Get partition result
169+ partition_result = self ._get_partition_result (module , inputs )
170+
171+ # Find all placeholder nodes (these represent constants, parameters, buffers, and inputs)
172+ constant_placeholders = []
173+ input_placeholders = []
174+
175+ for node in partition_result .tagged_exported_program .graph .nodes :
176+ if node .op == "placeholder" :
177+ # Check if this is a constant (param, buffer, or lifted tensor constant)
178+ from torch ._export .utils import (
179+ is_buffer ,
180+ is_lifted_tensor_constant ,
181+ is_param ,
182+ )
183+
184+ is_constant = (
185+ is_param (partition_result .tagged_exported_program , node )
186+ or is_buffer (partition_result .tagged_exported_program , node )
187+ or is_lifted_tensor_constant (
188+ partition_result .tagged_exported_program , node
189+ )
190+ )
191+
192+ if is_constant :
193+ constant_placeholders .append (node )
194+ else :
195+ input_placeholders .append (node )
196+
197+ # Verify we have constant placeholders
198+ self .assertGreater (
199+ len (constant_placeholders ),
200+ 0 ,
201+ "Expected to find constant placeholders in the graph" ,
202+ )
203+
204+ # Check that all constant placeholders are tagged, including unused ones
205+ untagged_constants = []
206+ for node in constant_placeholders :
207+ if "delegation_tag" not in node .meta :
208+ untagged_constants .append (node .name )
209+
210+ self .assertEqual (
211+ len (untagged_constants ),
212+ 0 ,
213+ f"All constant placeholders should be tagged. Found untagged constants: { untagged_constants } " ,
214+ )
215+
216+ # Verify all tagged constants have the expected tag
217+ expected_tag = "tag0"
218+ for node in constant_placeholders :
219+ actual_tag = node .meta .get ("delegation_tag" )
220+ self .assertEqual (
221+ actual_tag ,
222+ expected_tag ,
223+ f"Constant placeholder { node .name } has tag '{ actual_tag } ' but expected '{ expected_tag } '" ,
224+ )
0 commit comments