Commit e72ac0b

Update on "[llama-mm] Enable kv cache for MultiHeadAttention"

Summary: Change `MultiHeadAttention` in `extension/llm/modules` to support KV cache. The cache is enabled for eager mode only; export support is not included yet.

Test Plan: Unit tests.

Tags: [ghstack-poisoned]

1 parent d4e8c6e
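Since the cache path is eager-only for now, here is a minimal, self-contained sketch of what KV-cache attention driven by an `input_pos` tensor looks like in eager mode. `SimpleKVAttention`, its single-head layout, and the cache-write scheme are illustrative assumptions, not the actual `MultiHeadAttention` from `extension/llm/modules`:

```python
import torch
import torch.nn.functional as F

class SimpleKVAttention(torch.nn.Module):
    """Toy single-head attention with a pre-allocated KV cache (illustrative only)."""

    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        # Fixed-size caches; new tokens are written at the positions in `input_pos`.
        self.register_buffer("k_cache", torch.zeros(1, max_seq_len, dim))
        self.register_buffer("v_cache", torch.zeros(1, max_seq_len, dim))

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Write the new keys/values into the cache at the given positions.
        self.k_cache[:, input_pos] = k
        self.v_cache[:, input_pos] = v
        end = int(input_pos[-1]) + 1
        # Attend over everything cached so far (causal mask omitted for brevity).
        return F.scaled_dot_product_attention(
            q, self.k_cache[:, :end], self.v_cache[:, :end]
        )

attn = SimpleKVAttention(dim=8, max_seq_len=16)
prefill = attn(torch.randn(1, 4, 8), input_pos=torch.arange(4))       # positions 0..3
decode = attn(torch.randn(1, 1, 8), input_pos=torch.tensor([4]))      # one new token
```

The point to note is that the module owns the cache buffers and the caller supplies `input_pos` to say where new tokens land, which matches the `input_pos` keyword the updated tests below pass through.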

File tree

1 file changed: +6 −6 lines

extension/llm/modules/test/test_attention.py

Lines changed: 6 additions & 6 deletions
@@ -131,11 +131,11 @@ def test_attention_export(self):
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
         )
-        et_res = et_mha_ep.module()(self.x, self.x)
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         self.assertTrue(torch.allclose(et_res, tt_res))

         # TODO: KV cache.
@@ -149,7 +149,7 @@ def test_attention_executorch(self):
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
         )
         et_program = to_edge(
@@ -159,8 +159,8 @@ def test_attention_executorch(self):
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
-        et_res = method.execute((self.x, self.x))
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = method.execute((self.x, self.x, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
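The key change in the export test is that keyword inputs supplied to `torch.export.export` via `kwargs` must also be supplied when calling the exported module. A minimal, runnable sketch of that pattern, where `Toy` and `input_pos` are stand-ins for the test's module and cache-position tensor:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x, y, input_pos):
        # `input_pos` stands in for the cache-position tensor in the real test.
        return x + y + input_pos.to(x.dtype).mean()

m = Toy()
x = torch.randn(2, 4)
pos = torch.arange(4)

# Positional inputs go in `args`; named inputs go in `kwargs`, exactly as
# the updated test does with {"input_pos": self.input_pos}.
ep = torch.export.export(m, (x, x), kwargs={"input_pos": pos})
out = ep.module()(x, x, input_pos=pos)
assert torch.allclose(out, m(x, x, input_pos=pos))
```

On the ExecuTorch runtime side, keyword arguments are flattened into the positional input list, which is why the test passes `(self.x, self.x, self.input_pos)` to `method.execute`.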
