Commit 220aec8

Use mask instead of cond for attention conditional logic
1 parent ca32105 commit 220aec8

File tree: 3 files changed, +44 -49 lines

extension/llm/modules/attention.py
extension/llm/modules/kv_cache.py
extension/llm/modules/test/test_attention.py

extension/llm/modules/attention.py (28 additions, 37 deletions)

@@ -246,6 +246,7 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
+        s_y = y.shape[1] if y is not None else 0

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -262,9 +263,16 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)

-        def calculate_kv(y):
+        if y is None:
+            if self.kv_cache is None:
+                raise ValueError(
+                    "Must provide y input or use kv_cache to enable streaming decoding"
+                )
+            k = self.kv_cache.k_cache
+            v = self.kv_cache.v_cache
+        else:
             # Update k and v shape, positional embeddings, and normalization
-            s_y = y.shape[1]
+
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -280,37 +288,12 @@ def calculate_kv(y):
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
-            return k, v
-
-        def true_fn(y):
-            kv_cache = self.kv_cache.clone()
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

-        def false_fn(y):
-            k, v = calculate_kv(y)
-            kv_cache = self.kv_cache.clone()
-            kv_cache.update(k, v)
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
-
-        # If kv cache is None, we expect y to be provided
-        if self.kv_cache is None:
-            assert (
-                y is not None
-            ), "Must provide y input or use kv_cache to enable streaming decoding"
-            k, v = calculate_kv(y)
-        else:
-            # Expecting the k, v returning here to be the same size of self.kv_cache
-            # In eager, we expect this predicate to specialize. In export, this will
-            # become a SymBool so it's not specialized.
-            k, v, cache_pos = torch.cond(
-                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
-            )
         # Update key-value cache
-            self.kv_cache.k_cache.copy_(k)
-            self.kv_cache.v_cache.copy_(v)
-            self.kv_cache.cache_pos.copy_(cache_pos)
+        if self.kv_cache is not None and self.cache_enabled:
+            k, v = self.kv_cache.update(k, v)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x)
         return self.output_proj(output)

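Aside (not part of the diff): a condensed sketch of the control flow that replaces the torch.cond call above. The names (compute_kv, v_proj) are hypothetical illustration, not the module's real signature; the point is that the "read the cache" vs. "project y and write the cache" decision is now an ordinary Python branch on whether y was provided, with the cache write delegated to KVCache.update.

def compute_kv(self, y):
    # Sketch only, mirroring the new branch in forward.
    if y is None:
        if self.kv_cache is None:
            raise ValueError(
                "Must provide y input or use kv_cache to enable streaming decoding"
            )
        # Decode step with no new y: read keys/values straight from the cache.
        k, v = self.kv_cache.k_cache, self.kv_cache.v_cache
    else:
        # Prefill / cross-attention step: project y, then let the cache do the
        # masked in-place update and return the full-length k/v buffers.
        k, v = self.k_proj(y), self.v_proj(y)
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
    return k, v
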
@@ -352,17 +335,25 @@ def forward(
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.

-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        # k: [bsz, seq_len, n_kv, 1, h_d]
+        # v: [bsz, seq_len, n_kv, 1, h_d]
+        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)

         # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+
+        # [bsz, s, n_h, h_d]
+        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
+        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
+
+        # [bsz, n_h, s, h_d]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)

         output = self._attention_fn(
             q,

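Aside (not part of the diff): the reworked _sdpa code above expands grouped-query k/v heads with view/expand/reshape before the transpose, instead of transposing first and using unsqueeze/expand/flatten. A small standalone sketch with toy shapes, assuming k starts as [bsz, seq_len, num_kv_heads, head_dim], showing that both orderings repeat each kv head q_per_kv times and land on the same [bsz, num_heads, seq_len, head_dim] layout:

import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 4, 8, 2, 16
q_per_kv = num_heads // num_kv_heads
k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

# New ordering: singleton group dim -> expand -> merge head dims -> transpose.
k_new = (
    k.view(bsz, -1, num_kv_heads, 1, head_dim)
    .expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)
    .reshape(bsz, -1, num_heads, head_dim)
    .transpose(1, 2)
)

# Old ordering: transpose first, then unsqueeze/expand/flatten on the head dims.
k_old = (
    k.transpose(1, 2)
    .unsqueeze(2)
    .expand(-1, -1, q_per_kv, -1, -1)
    .flatten(1, 2)
)

print(k_new.shape)                # torch.Size([2, 8, 4, 16])
print(torch.equal(k_new, k_old))  # True: both repeat each kv head q_per_kv times
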
extension/llm/modules/kv_cache.py (8 additions, 5 deletions)

@@ -111,11 +111,13 @@ def update(
         v_out = self.v_cache

         if self.transpose_cache:
-            k_out[:, :, self.cache_pos[:seq_len]] = k_val
-            v_out[:, :, self.cache_pos[:seq_len]] = v_val
+            pos_mask = torch.arange(k_out.shape[2]) < seq_len
+            k_out[:, :, self.cache_pos[pos_mask]] = k_val
+            v_out[:, :, self.cache_pos[pos_mask]] = v_val
         else:
-            k_out[:, self.cache_pos[:seq_len]] = k_val
-            v_out[:, self.cache_pos[:seq_len]] = v_val
+            pos_mask = torch.arange(k_out.shape[1]) < seq_len
+            k_out[:, self.cache_pos[pos_mask], :] = k_val
+            v_out[:, self.cache_pos[pos_mask], :] = v_val

         # forward cache_pos seq_len positions along
         # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +126,8 @@ def update(
         # this allows us to track the current position in the cache
         # after the last update in a compile-friendly way without any dynamism
         # e.g. relying on an int size tracker, or re-creating cache_pos every time
-        self.cache_pos.add_(seq_len)
+        mask = (seq_len > 0) * 1
+        self.cache_pos.add_(seq_len * mask)

         return k_out, v_out

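Aside (not part of the diff): a toy sketch of the indexing trick above, assuming cache_pos starts as arange(max_seq_len) as the comments in update describe. The boolean pos_mask selects the same first seq_len positions as the old [:seq_len] slice, but without a data-dependent slice length, and the position counter only advances when something was actually written:

import torch

max_seq_len, seq_len, dim = 10, 3, 4
cache_pos = torch.arange(max_seq_len)
k_out = torch.zeros(1, max_seq_len, dim)  # toy non-transposed cache: [b, max_seq_len, dim]
k_val = torch.ones(1, seq_len, dim)       # keys for the 3 newly decoded positions

# Equivalent to k_out[:, cache_pos[:seq_len], :] = k_val, expressed as a mask.
pos_mask = torch.arange(k_out.shape[1]) < seq_len
k_out[:, cache_pos[pos_mask], :] = k_val

# Advance cache_pos, gated so a zero-length update leaves it unchanged.
mask = (seq_len > 0) * 1
cache_pos.add_(seq_len * mask)

print(k_out[0, :5, 0])  # tensor([1., 1., 1., 0., 0.])
print(cache_pos[:5])    # tensor([3, 4, 5, 6, 7])
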
extension/llm/modules/test/test_attention.py (8 additions, 7 deletions)

@@ -219,28 +219,29 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run, for the value of the second parameter in the forward, it doesn't matter
+        # whether it is the same as the first (self attention) or if it is different (cross
+        # attention).
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        )
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        )

         self.assertTrue(torch.allclose(et_res, tt_res))

-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run test kv cache read. Input pos is [10, 11, ..., 19].
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
         et_res = self.et_mha(
             self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        )  # Cross attention with no y input, ET uses a tensor of empty values.
         tt_res = self.tt_mha(
             self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        )  # Cross attention with no y input, TorchTune uses None.

         assert_close(et_res, tt_res)

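Aside (not part of the diff): the test builds its attention masks by row-indexing a precomputed causal mask with the current input positions. A minimal sketch, assuming causal_mask is the usual lower-triangular boolean matrix of shape [max_seq_len, max_seq_len] and that the first run covered positions 0..9:

import torch

max_seq_len = 20
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# First run: positions 0..9 attend causally within the prompt.
input_pos = torch.arange(10).unsqueeze(0)            # [1, 10]
prefill_mask = causal_mask[input_pos, :]             # [1, 10, max_seq_len]

# Second run: positions 10..19 also see everything already in the cache.
next_input_pos = torch.arange(10, 20).unsqueeze(0)   # [1, 10]
decode_mask = causal_mask[next_input_pos, :]         # [1, 10, max_seq_len]

print(prefill_mask.shape, decode_mask.shape)
print(decode_mask[0, 0, :12])  # position 10 attends to positions 0..10 only
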