@@ -232,7 +232,9 @@ def partition(
232232
233233 # Check Owning Program still owns all constant data
234234 owning_program = delegated .exported_program ()
235- self .assertEqual (len (owning_program .state_dict ), 3 )
235+ self .assertEqual (
236+ len (owning_program .state_dict ) + len (owning_program .constants ), 3
237+ )
236238 self .assertEqual (len (owning_program .graph_signature .buffers ), 2 )
237239 self .assertEqual (len (owning_program .graph_signature .parameters ), 1 )
238240
@@ -321,7 +323,7 @@ def partition(
321323 delegated .exported_program ().graph_module , lowered_module_node .name
322324 )
323325 delegated_ep = lower_module .original_module
324- self .assertEqual (len (delegated_ep .state_dict ), 3 )
326+ self .assertEqual (len (delegated_ep .state_dict ) + len ( delegated_ep . constants ) , 3 )
325327 self .assertEqual (len (delegated_ep .graph_signature .buffers ), 2 )
326328 self .assertEqual (len (delegated_ep .graph_signature .parameters ), 1 )
327329
@@ -375,7 +377,9 @@ def partition(
375377
376378 # Check Owning Program still owns only buffers
377379 owning_program = delegated .exported_program ()
378- self .assertEqual (len (owning_program .state_dict ), 2 )
380+ self .assertEqual (
381+ len (owning_program .state_dict ) + len (owning_program .constants ), 2
382+ )
379383 self .assertEqual (len (owning_program .graph_signature .buffers ), 2 )
380384 self .assertEqual (len (owning_program .graph_signature .parameters ), 0 )
381385
0 commit comments