| 
9 | 9 | import typing  | 
10 | 10 | import unittest  | 
11 | 11 | from contextlib import contextmanager  | 
 | 12 | +from copy import deepcopy  | 
12 | 13 | from typing import List, Optional, Tuple  | 
13 | 14 | 
 
  | 
14 | 15 | import executorch.exir as exir  | 
 | 
31 | 32 | from executorch.exir.error import InternalError  | 
32 | 33 | from executorch.exir.passes import MemoryPlanningPass  | 
33 | 34 | from executorch.exir.passes.constant_prop_pass import constant_prop_pass  | 
 | 35 | +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass  | 
34 | 36 | from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass  | 
35 | 37 | from executorch.exir.print_program import pretty_print, print_program  # noqa  | 
36 | 38 | from executorch.exir.schema import (  | 
 | 
56 | 58 | from executorch.extension.pybindings.portable_lib import (  | 
57 | 59 |     _load_for_executorch_from_buffer,  | 
58 | 60 | )  | 
 | 61 | +from executorch.runtime import Runtime  | 
59 | 62 | 
 
  | 
60 | 63 | from functorch.experimental import control_flow  | 
61 | 64 | from torch import nn  | 
@@ -243,6 +246,56 @@ def forward(self, x):  | 
243 | 246 |         )  | 
244 | 247 |         self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)  | 
245 | 248 | 
 
  | 
 | 249 | +    def test_initialized_mutable_buffer(self):  | 
 | 250 | +        """Test that mutable buffers can hold meaningful initialized state."""  | 
 | 251 | + | 
 | 252 | +        class TestModule(torch.nn.Module):  | 
 | 253 | +            def __init__(self):  | 
 | 254 | +                super().__init__()  | 
 | 255 | +                # Mutable buffer with non-empty initial state.  | 
 | 256 | +                self.register_buffer("cache_pos", torch.arange(0, 10))  | 
 | 257 | + | 
 | 258 | +            def forward(self, x):  | 
 | 259 | +                self.cache_pos.add_(1)  | 
 | 260 | +                return self.cache_pos  | 
 | 261 | + | 
 | 262 | +        m = TestModule()  | 
 | 263 | +        example_inputs = (torch.ones(10),)  | 
 | 264 | +        ep = torch.export.export(m, example_inputs)  | 
 | 265 | +        edge = to_edge(  | 
 | 266 | +            ep,  | 
 | 267 | +            compile_config=EdgeCompileConfig(  | 
 | 268 | +                _check_ir_validity=False,  | 
 | 269 | +            ),  | 
 | 270 | +        )  | 
 | 271 | + | 
 | 272 | +        # Save a copy of the edge program since to_executorch is  | 
 | 273 | +        # stateful to some degree.  | 
 | 274 | +        edge_copy = deepcopy(edge)  | 
 | 275 | +        et_config = ExecutorchBackendConfig(  | 
 | 276 | +            passes=[InitializedMutableBufferPass(["cache_pos"])],  | 
 | 277 | +        )  | 
 | 278 | +        et_program_init_pass = edge.to_executorch(config=et_config)  | 
 | 279 | +        et_program_regular = edge_copy.to_executorch()  | 
 | 280 | + | 
 | 281 | +        runtime = Runtime.get()  | 
 | 282 | +        program_init_pass = runtime.load_program(et_program_init_pass.buffer)  | 
 | 283 | +        method_init_pass = program_init_pass.load_method("forward")  | 
 | 284 | + | 
 | 285 | +        program_regular = runtime.load_program(et_program_regular.buffer)  | 
 | 286 | +        method_regular = program_regular.load_method("forward")  | 
 | 287 | + | 
 | 288 | +        # Test that the mutable buffer is initialized.  | 
 | 289 | +        torch.allclose(  | 
 | 290 | +            method_init_pass.execute((example_inputs))[0], torch.arange(1, 11)  | 
 | 291 | +        )  | 
 | 292 | +        # Test that the mutable buffer is uninitialized and starts with default zeros,  | 
 | 293 | +        # we test equality with torch.ones because of the mutation += 1 in the model forward.  | 
 | 294 | +        torch.allclose(  | 
 | 295 | +            method_regular.execute((example_inputs))[0],  | 
 | 296 | +            torch.ones(10, dtype=torch.int64),  | 
 | 297 | +        )  | 
 | 298 | + | 
246 | 299 |     def test_int_list_input(self):  | 
247 | 300 |         class M(torch.nn.Module):  | 
248 | 301 |             def forward(self, x, y, z):  | 
 | 
0 commit comments