
Commit e1ac7ea

Lookahead decoding on static attention
Differential Revision: D76741091 Pull Request resolved: #12276
1 parent 83dc127 commit e1ac7ea
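
This commit adds lookahead decoding to StaticAttentionIOManager: each forward pass carries a set of lookahead windows that speculate future tokens plus a set of verification branches that check previously collected n-grams, so several tokens can be accepted per inference step when a cached n-gram matches the greedy continuation. A minimal usage sketch, assuming an exported static-attention model, a config (ModelArgs), and a list of prompt token ids named prompt; the pattern mirrors the new unit test below:

    input_len, cache_len = 32, config.max_seq_len - 32
    mgr = StaticAttentionIOManager(config, input_len, cache_len)
    init_token = mgr.prefill(model, prompt)[0][-1].argmax().item()
    tokens = mgr.lookahead_decode(
        model,
        init_token,
        50,                 # target number of new tokens (may overshoot slightly)
        ngram_size=3,       # length of speculated n-grams
        window_size=8,      # width of each lookahead window
        n_verifications=8,  # number of n-gram verification branches
    )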

2 files changed: +277, -10 lines

examples/models/llama/static_attention.py

Lines changed: 214 additions & 1 deletion
@@ -1,4 +1,6 @@
+import logging
 from abc import ABC, abstractmethod
+from collections import defaultdict, deque
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
@@ -14,6 +16,7 @@
 from executorch.examples.models.llama.rope import Rope
 
 
+logger = logging.getLogger(__name__)
 _CacheMap = Dict[str, torch.Tensor]
 # Key and value caches are kept separate so the key caches can be kept transposed.
 _InputCacheState = Tuple[_CacheMap, _CacheMap]
@@ -174,6 +177,24 @@ def unmask(self, new_unmasked_len):
 
 
 class StaticAttentionIOManager:
+    class NGramCache:
+        def __init__(self, max_size):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, x):
+            if x in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(x)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         config: ModelArgs,
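
The nested NGramCache added here is a small bounded, de-duplicating FIFO (keyed by token id via a defaultdict later in lookahead_decode). A standalone sketch of its behavior with illustrative values:

    from collections import deque

    # Equivalent of NGramCache(max_size=2): skip duplicates, evict the oldest when full.
    cache, max_size = deque(), 2

    def add(ngram):
        if ngram in cache:
            return
        if len(cache) == max_size:
            cache.popleft()
        cache.append(ngram)

    add([7, 9]); add([7, 9]); add([3, 4]); add([5, 6])
    print(list(cache))  # [[3, 4], [5, 6]]: the duplicate was skipped, the oldest entry evicted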
@@ -266,12 +287,143 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, :].argmax().item())
+            new_tokens.append(y[:, :1, ...].argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break
 
         return new_tokens
 
+    def lookahead_decode(  # noqa: C901
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "StaticAttentionIOManager.NGramCache"]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.input_len:
+            raise RuntimeError(
+                "Lookahead decoding setting not compatible with input length."
+                f" input_len = {self.input_len},"
+                f" ngram_size = {ngram_size},"
+                f" window_size = {window_size},"
+                f" n_verifications = {n_verifications}"
+            )
+
+        stop_tokens = stop_tokens or []
+        if ngram_caches is None:
+            ngram_caches = defaultdict(
+                lambda: StaticAttentionIOManager.NGramCache(n_verifications)
+            )
+
+        self.mask.tensor[:, :, self.cache_len :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.input_len):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.mask.tensor[0][i][self.cache_len :]
+                )
+            )
+
+        pos_offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        new_tokens = [init_token]
+        x = [init_token] * self.input_len
+        inference_cnt = 0
+        while len(new_tokens) < n + 1:
+            # Update verification branch with cached n-grams.
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            y, attn_updates = self._run_once(
+                model,
+                x,
+                non_padded_len=1,
+                freqs_cos_override=self.freqs_cos[pos_offsets + self.pos],
+                freqs_sin_override=self.freqs_sin[pos_offsets + self.pos],
+            )
+            inference_cnt += 1
+            # Only supports greedy decoding for now.
+            y = y[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams.
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                # for j in range(ngram_size - 1):
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+
+                # Update KV caches and attention mask for the additional matched tokens.
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self._update_states(
+                    attn_updates,
+                    update_pos=branch_offset,
+                    update_len=len(longest_match),
+                )
+
+            # Update lookahead branch.
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_cnt} inference(s)."
+        )
+        return new_tokens
+
     def _run_once(
         self,
         model: Callable[..., Any],
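
The static input buffer that lookahead_decode refills on every step holds the current token and the lookahead windows in the first window_size * (ngram_size - 1) slots, followed by the verification branches, which is where the (ngram_size - 1) * (window_size + n_verifications) <= input_len requirement comes from. A small worked example with illustrative values (the new unit test uses ngram_size=3, window_size=8, n_verifications=8 with input_len=32):

    # Illustrative sizes only; mirrors the arithmetic in lookahead_decode above.
    ngram_size, window_size, n_verifications = 3, 2, 2

    lookahead_len = window_size * (ngram_size - 1)         # 4 slots across 2 windows
    verification_len = n_verifications * (ngram_size - 1)  # 4 slots across 2 branches
    verification_offset = max(lookahead_len, 1)            # verification branches start at index 4

    assert lookahead_len + verification_len <= 8           # fits an input_len of 8
    print(verification_offset, lookahead_len + verification_len)  # 4 8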
@@ -330,6 +482,67 @@ def _update_states(self, attn_updates, update_pos, update_len):
         )
         self.pos += update_len
 
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.input_len, self.input_len), self.mask_val)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.mask_val),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.mask_val),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        # Input position offsets, used for indexing RoPE frequencies.
+        pos_offsets = torch.zeros(self.input_len, dtype=torch.int32)
+        idx = 0
+        # Lookahead branches: [i + 0, i + 1, ..., i + window_size - 1] for time i.
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
 
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
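
To make the position helper above concrete, here is a standalone sketch that reproduces _get_lookahead_position_offsets for small illustrative parameters (it keeps only the window_size > 0 path): each lookahead window i covers offsets i through i + window_size - 1 from the current position, and every verification branch covers offsets 1 through ngram_size - 1.

    import torch

    def lookahead_position_offsets(ngram_size, window_size, n_verifications, input_len):
        # Sketch of the helper above; omits the window_size == 0 special case.
        pos = torch.zeros(input_len, dtype=torch.int32)
        idx = 0
        for i in range(ngram_size - 1):      # lookahead windows
            for j in range(window_size):
                pos[idx] = i + j
                idx += 1
        for _ in range(n_verifications):     # verification branches
            for j in range(1, ngram_size):
                pos[idx] = j
                idx += 1
        return pos

    print(lookahead_position_offsets(3, 2, 2, 8).tolist())  # [0, 1, 1, 2, 1, 2, 1, 2]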

examples/models/llama/tests/test_static_attention.py

Lines changed: 63 additions & 9 deletions
@@ -1,4 +1,5 @@
 import unittest
+from collections import defaultdict
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA
@@ -164,15 +165,7 @@ def test_with_style(style):
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
 
-    def test_within_transformer(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=24,
-            n_layers=4,
-            vocab_size=128,
-        )
+    def _get_test_transformers(self, config):
         mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
@@ -183,6 +176,18 @@ def test_within_transformer(self):
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
+        return mha_transformer, static_transformer
+
+    def test_within_transformer(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=24,
+            n_layers=4,
+            vocab_size=128,
+        )
+        mha_transformer, static_transformer = self._get_test_transformers(config)
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
         expected = mha_transformer(x)
 
@@ -204,3 +209,52 @@ def test_with_style(style):
 
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
+
+    def test_lookahead_decode(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=128,
+            n_layers=4,
+            vocab_size=128,
+            generate_full_logits=True,
+        )
+        _, static_transformer = self._get_test_transformers(config)
+
+        input_len = 32
+        cache_len = config.max_seq_len - input_len
+        prefill_input = torch.randint(config.vocab_size, (input_len,))
+        ref_mgr = StaticAttentionIOManager(config, input_len, cache_len)
+        lookahead_mgr = StaticAttentionIOManager(config, input_len, cache_len)
+
+        next_tok = (
+            ref_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
+            .argmax()
+            .item()
+        )
+        ref_output = ref_mgr.decode(static_transformer, next_tok, 50)
+
+        ngram_size = 3
+        window_size = 8
+        n_verifications = 8
+        ngram_caches = defaultdict(
+            lambda: StaticAttentionIOManager.NGramCache(n_verifications)
+        )
+        for _ in range(2):  # run twice; the first run populates the cache
+            lookahead_mgr.reset()
+            next_tok = (
+                lookahead_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
+                .argmax()
+                .item()
+            )
+            lookahead_output = lookahead_mgr.lookahead_decode(
+                static_transformer,
+                next_tok,
+                50,
+                ngram_size=ngram_size,
+                window_size=window_size,
+                n_verifications=n_verifications,
+                ngram_caches=ngram_caches,
+            )
+            self.assertEqual(lookahead_output[: len(ref_output)], ref_output)
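
Since verification only accepts n-gram tokens that exactly match the model's greedy predictions, the lookahead run is expected to reproduce plain greedy decoding token for token; it can also run a few tokens past n when a whole matched n-gram is appended at once, which is why the assertion compares the reference against a prefix of the lookahead output rather than the full list.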
