@@ -139,3 +139,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
139139            fully_partitioned ,
140140            "Graph should be fully partitioned with all operators having the same tag" ,
141141        )
142+ 
143+     def  test_unused_constant_tagging (self ):
144+         """ 
145+         Test that constant nodes without users are properly tagged with delegation_tag. 
146+ 
147+         When a graph contains constants (parameters, buffers, or lifted tensor constants) 
148+         that are not used by any operations, the CUDA partitioner should still tag them 
149+         with the delegation_tag. This ensures all constant data is properly handled during 
150+         delegation, even if they have no users in the graph. 
151+         """ 
152+ 
153+         class  ModuleWithUnusedConst (torch .nn .Module ):
154+             def  __init__ (self ):
155+                 super ().__init__ ()
156+                 # Register a buffer that won't be used in forward 
157+                 self .register_buffer ("unused_buffer" , torch .randn (10 , 10 ))
158+                 # Also register a used parameter 
159+                 self .weight  =  torch .nn .Parameter (torch .randn (5 , 5 ))
160+ 
161+             def  forward (self , x : torch .Tensor ) ->  torch .Tensor :
162+                 # Only use the weight parameter, not the unused_buffer 
163+                 return  x  +  self .weight 
164+ 
165+         module  =  ModuleWithUnusedConst ()
166+         inputs  =  (torch .randn (5 , 5 ),)
167+ 
168+         # Get partition result 
169+         partition_result  =  self ._get_partition_result (module , inputs )
170+ 
171+         # Find all placeholder nodes (these represent constants, parameters, buffers, and inputs) 
172+         constant_placeholders  =  []
173+         input_placeholders  =  []
174+ 
175+         for  node  in  partition_result .tagged_exported_program .graph .nodes :
176+             if  node .op  ==  "placeholder" :
177+                 # Check if this is a constant (param, buffer, or lifted tensor constant) 
178+                 from  torch ._export .utils  import  (
179+                     is_buffer ,
180+                     is_lifted_tensor_constant ,
181+                     is_param ,
182+                 )
183+ 
184+                 is_constant  =  (
185+                     is_param (partition_result .tagged_exported_program , node )
186+                     or  is_buffer (partition_result .tagged_exported_program , node )
187+                     or  is_lifted_tensor_constant (
188+                         partition_result .tagged_exported_program , node 
189+                     )
190+                 )
191+ 
192+                 if  is_constant :
193+                     constant_placeholders .append (node )
194+                 else :
195+                     input_placeholders .append (node )
196+ 
197+         # Verify we have constant placeholders 
198+         self .assertGreater (
199+             len (constant_placeholders ),
200+             0 ,
201+             "Expected to find constant placeholders in the graph" ,
202+         )
203+ 
204+         # Check that all constant placeholders are tagged, including unused ones 
205+         untagged_constants  =  []
206+         for  node  in  constant_placeholders :
207+             if  "delegation_tag"  not  in   node .meta :
208+                 untagged_constants .append (node .name )
209+ 
210+         self .assertEqual (
211+             len (untagged_constants ),
212+             0 ,
213+             f"All constant placeholders should be tagged. Found untagged constants: { untagged_constants }  " ,
214+         )
215+ 
216+         # Verify all tagged constants have the expected tag 
217+         expected_tag  =  "tag0" 
218+         for  node  in  constant_placeholders :
219+             actual_tag  =  node .meta .get ("delegation_tag" )
220+             self .assertEqual (
221+                 actual_tag ,
222+                 expected_tag ,
223+                 f"Constant placeholder { node .name }   has tag '{ actual_tag }  ' but expected '{ expected_tag }  '" ,
224+             )
0 commit comments