@@ -519,6 +519,39 @@ def test_multiple_pools(
519519 idx += 1
520520 self .assertEqual (graph_module .meta ["non_const_buffer_sizes" ], expected_bufsizes )
521521
522+ def test_mutation_not_double_allocated (self ) -> None :
523+ class Simple (torch .nn .Module ):
524+ def __init__ (self ) -> None :
525+ super ().__init__ ()
526+ self .register_buffer ("constant" , torch .ones (5 , 5 ))
527+
528+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
529+ self .constant .add_ (1 )
530+ return x - self .constant
531+
532+ model = Simple ()
533+ inputs = (torch .ones (5 , 5 ),)
534+
535+ et = to_edge (export (model , inputs , strict = True )).to_executorch ()
536+
537+ # 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
538+ self .assertEqual (
539+ et .executorch_program .execution_plan [0 ]
540+ .values [0 ]
541+ .val .allocation_info .memory_offset_low ,
542+ et .executorch_program .execution_plan [0 ]
543+ .values [11 ]
544+ .val .allocation_info .memory_offset_low ,
545+ )
546+ self .assertEqual (
547+ et .executorch_program .execution_plan [0 ]
548+ .values [0 ]
549+ .val .allocation_info .memory_offset_high ,
550+ et .executorch_program .execution_plan [0 ]
551+ .values [11 ]
552+ .val .allocation_info .memory_offset_high ,
553+ )
554+
522555 def test_constants_not_memory_planned (self ) -> None :
523556 class Simple (torch .nn .Module ):
524557 def __init__ (self ) -> None :
0 commit comments