
Commit b7763e9

Add test
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 32b31a9 commit b7763e9

3 files changed: +47 -8 lines changed

extension/llm/modules/attention.py

Lines changed: 1 addition & 1 deletion
@@ -310,7 +310,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)

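For context (not part of this commit): the change simply threads the attention mask through to the SDPA call. Below is a minimal sketch of how a boolean causal-mask slice like the one built in the test file further down can be consumed, assuming the underlying kernel behaves like torch.nn.functional.scaled_dot_product_attention; the module's own _sdpa wrapper is not shown here and all sizes are made up for illustration.

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
b, n_heads, s, head_dim, max_seq_len = 1, 8, 10, 64, 128

q = torch.randn(b, n_heads, s, head_dim)
k = torch.randn(b, n_heads, s, head_dim)
v = torch.randn(b, n_heads, s, head_dim)

# Boolean causal mask, True = "may attend", built with torch.tril as in the test.
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(s).unsqueeze(0)  # positions processed in this step
mask = causal_mask[input_pos, :s]         # (1, s, s) here; the test keeps all max_seq_len key columns because keys live in the cache

# A boolean attn_mask broadcastable to (b, n_heads, s_q, s_k) is accepted directly;
# passing it explicitly is what the new `mask=mask` argument enables upstream.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
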
extension/llm/modules/kv_cache.py

Lines changed: 4 additions & 5 deletions
@@ -131,17 +131,16 @@ def update(
     def clone(self) -> "KVCache":
         """Create a clone of the KVCache."""
         if self.transpose_cache:
-            max_seq_len = self.k_cache.shape[1]
-            num_kv_heads = self.k_cache.shape[2]
-        else:
-            max_seq_len = self.k_cache.shape[2]
             num_kv_heads = self.k_cache.shape[1]
+        else:
+            num_kv_heads = self.k_cache.shape[2]
         clone = KVCache(
             batch_size=self.batch_size,
-            max_seq_len=max_seq_len,
+            max_seq_len=self.max_seq_len,
             num_kv_heads=num_kv_heads,
             head_dim=self.k_cache.shape[3],
             dtype=self.k_cache.dtype,
+            transpose_cache=self.transpose_cache,
         )
         clone.k_cache.copy_(self.k_cache)
         clone.v_cache.copy_(self.v_cache)

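For reference (an illustrative sketch, not code from this commit): the fix relies on the two cache layouts implied by the diff above. With transpose_cache the k/v buffers are shaped [batch, num_kv_heads, max_seq_len, head_dim], otherwise [batch, max_seq_len, num_kv_heads, head_dim], so num_kv_heads lives on dim 1 in the first case and dim 2 in the second, while max_seq_len and transpose_cache are now taken straight from the instance.

import torch

# Hypothetical sizes for illustration only.
batch_size, num_kv_heads, max_seq_len, head_dim = 1, 8, 128, 64

k_transposed = torch.zeros(batch_size, num_kv_heads, max_seq_len, head_dim)  # transpose_cache=True
k_regular = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim)     # transpose_cache=False

assert k_transposed.shape[1] == num_kv_heads  # dim 1 holds the heads when transposed
assert k_regular.shape[2] == num_kv_heads     # dim 2 holds the heads otherwise
assert k_transposed.shape[3] == head_dim == k_regular.shape[3]  # head_dim is dim 3 in both
# clone() now takes max_seq_len from self.max_seq_len rather than re-deriving it
# from a shape axis, and forwards transpose_cache so the clone keeps the same layout.
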
extension/llm/modules/test/test_attention.py

Lines changed: 42 additions & 2 deletions
@@ -27,7 +27,7 @@ def setUp(self):
         torch.manual_seed(0)
         # Constants
         self.embed_dim = 2048
-        self.num_heads = 32
+        self.num_heads = 8
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128

@@ -46,7 +46,9 @@ def setUp(self):
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
         self.v_proj.weight.requires_grad = False
-        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.output_proj = torch.nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=False
+        )
         self.pos_embeddings = Llama3ScaledRoPE(
             dim=self.head_dim,
             max_seq_len=self.max_seq_len,

@@ -92,6 +94,12 @@ def setUp(self):
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )
 
     def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.

@@ -197,3 +205,35 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
+
+    def test_attention_torch_cond_eager(self):
+        # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond,
+        # so we need to make sure both give the same results around that condition.
+        # For the first run of MHA we provide `y` (self.x), but for the second run it will be a tensor full of nan.
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+
+        # mask
+        mask = self.causal_mask[self.input_pos, :]
+        # First run
+        et_res = self.et_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+
+        empty_y = torch.full_like(self.x, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = self.et_mha(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, None, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+
+        assert_close(et_res, tt_res)

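The comment in the new test refers to rewriting a data-dependent if with torch.cond. As a rough illustration of that pattern (a sketch only, not the module's actual implementation; the helper name and operands are made up), torch.cond takes a tensor predicate, two branch functions with matching signatures and outputs, and a tuple of operands, which is what lets the NaN-filled y in the second run select the cached path in an export-friendly way.

import torch

def pick_k(y: torch.Tensor, cached_k: torch.Tensor) -> torch.Tensor:
    # Eager-style version: `return cached_k if torch.isnan(y).all() else y`.
    # torch.cond expresses the same choice as a traceable control-flow op; both
    # branches must take the same operands and return tensors of matching shape/dtype.
    def true_fn(y, cached_k):
        return cached_k  # y is all NaN: reuse what is already in the cache

    def false_fn(y, cached_k):
        return y         # y carries real data: use it (stand-in for the real projection)

    return torch.cond(torch.isnan(y).all(), true_fn, false_fn, (y, cached_k))

# The test drives both paths: a real `y` on the first call, then
# torch.full_like(x, torch.nan) on the second, with assert_close checking parity
# against the vanilla torchtune module.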