# Fix executorch kv cache incompatibility with to_executorch lowering #7279
Changes from 18 commits
```diff
@@ -9,6 +9,7 @@
 import typing
 import unittest
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import List, Optional, Tuple
 
 import executorch.exir as exir
@@ -31,6 +32,7 @@
 from executorch.exir.error import InternalError
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
@@ -56,6 +58,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
+from executorch.runtime import Runtime
 
 from functorch.experimental import control_flow
 from torch import nn
@@ -243,6 +246,56 @@ def forward(self, x):
         )
         self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
 
+    def test_initialized_mutable_buffer(self):
+        """Test that mutable buffers can hold meaningful initialized state."""
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Mutable buffer with a non-empty initial state.
+                self.register_buffer("cache_pos", torch.arange(0, 10))
+
+            def forward(self, x):
+                self.cache_pos.add_(1)
+                return self.cache_pos
+
+        m = TestModule()
+        example_inputs = (torch.ones(10),)
+        ep = torch.export.export(m, example_inputs)
+        edge = to_edge(
+            ep,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+            ),
+        )
+
+        # Save a copy of the edge program, since to_executorch is
+        # stateful to some degree.
+        edge_copy = deepcopy(edge)
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program_init_pass = edge.to_executorch(config=et_config)
+        et_program_regular = edge_copy.to_executorch()
+
+        runtime = Runtime.get()
+        program_init_pass = runtime.load_program(et_program_init_pass.buffer)
+        method_init_pass = program_init_pass.load_method("forward")
+
+        program_regular = runtime.load_program(et_program_regular.buffer)
+        method_regular = program_regular.load_method("forward")
+
+        # With the pass, the mutable buffer keeps its initialized state:
+        # arange(0, 10) plus the add_(1) in forward gives arange(1, 11).
+        self.assertTrue(
+            torch.allclose(
+                method_init_pass.execute(example_inputs)[0], torch.arange(1, 11)
+            )
+        )
+        # Without the pass, the mutable buffer is uninitialized and starts
+        # from the default zeros; the add_(1) in forward turns it into ones.
+        self.assertTrue(
+            torch.allclose(
+                method_regular.execute(example_inputs)[0],
+                torch.ones(10, dtype=torch.int64),
```
> **Contributor:** Should this be `torch.zeros`?
>
> **Contributor (Author):** Oh, it's because in the forward of the model we do `self.cache_pos.add_(1)`; without the pass the buffer starts as zeros, so the output is ones.
```diff
+            )
+        )
 
     def test_int_list_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
```
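For reference, the user-facing recipe that this test exercises is small. A minimal sketch, assuming a `model` and `example_inputs` like those in the test and using only the APIs the test itself imports:

```python
import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# Export and lower, asking the emitter to keep the initial value of any
# buffer whose name contains "cache_pos" (e.g. a kv-cache position index).
ep = torch.export.export(model, example_inputs)
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
et_program = edge.to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)
```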
New file `exir/passes/init_mutable_pass.py` (the module the test imports `InitializedMutableBufferPass` from):

```diff
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+from executorch.exir.pass_base import ExportPass
+
+
+class InitializedMutableBufferPass(ExportPass):
+    """
+    If a buffer's name contains one of the given patterns, for example
+    "cache_pos" in a kv_cache module that calls
+    `self.register_buffer("cache_pos", torch.arange(10))`, mark it with
+    a custom tag which the emitter later uses to flag spec.const to
+    True, providing the mutable buffer with an initialized state.
+    """
+
+    def __init__(self, patterns: List[str]) -> None:
+        super().__init__()
+        self.patterns = patterns
+
+    def placeholder(self, name: str, arg, meta):
+        # Tag every placeholder whose name contains one of the patterns.
+        for pattern in self.patterns:
+            if pattern in name:
+                meta["et_init_buffer"] = True
+
+        return super().placeholder(name, arg, meta)
```
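The emitter-side change that consumes this tag is not part of the diff shown here. As a rough sketch of the mechanism the docstring describes (hypothetical code; only the `"et_init_buffer"` key comes from the pass above, the `node`/`spec` handling is assumed):

```python
# Hypothetical sketch of how an emitter could honor the tag; this is not
# the actual emitter code from this PR.
def apply_init_buffer_tag(node, spec) -> None:
    # The pass stored the flag in the placeholder node's metadata.
    if node.meta.get("et_init_buffer", False):
        # Flagging the spec as const causes the buffer's initial value to
        # be serialized into the program, instead of the buffer being
        # zero-initialized at load time.
        spec.const = True
```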