|
| 1 | +import logging |
1 | 2 | from abc import ABC, abstractmethod
|
| 3 | +from collections import defaultdict, deque |
2 | 4 | from typing import Any, Callable, Dict, List, Optional, Tuple
|
3 | 5 |
|
4 | 6 | import torch
|
|
14 | 16 | from executorch.examples.models.llama.rope import Rope
|
15 | 17 |
|
16 | 18 |
|
| 19 | +logger = logging.getLogger(__name__) |
17 | 20 | _CacheMap = Dict[str, torch.Tensor]
|
18 | 21 | # Key and value caches are kept separate so the key caches can be kept transposed.
|
19 | 22 | _InputCacheState = Tuple[_CacheMap, _CacheMap]
|
@@ -174,6 +177,24 @@ def unmask(self, new_unmasked_len):
|
174 | 177 |
|
175 | 178 |
|
176 | 179 | class StaticAttentionIOManager:
|
| 180 | + class NGramCache: |
| 181 | + def __init__(self, max_size): |
| 182 | + self.cache = deque() |
| 183 | + self.max_size = max_size |
| 184 | + |
| 185 | + def add(self, x): |
| 186 | + if x in self.cache: |
| 187 | + return |
| 188 | + if len(self.cache) == self.max_size: |
| 189 | + self.cache.popleft() |
| 190 | + self.cache.append(x) |
| 191 | + |
| 192 | + def __iter__(self): |
| 193 | + return iter(self.cache) |
| 194 | + |
| 195 | + def __str__(self): |
| 196 | + return str(self.cache) |
| 197 | + |
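For reference, a minimal illustration (all values hypothetical, not from this diff) of how the `NGramCache` added above behaves: it keeps at most `max_size` distinct n-grams in FIFO order, ignores duplicates, and evicts the oldest entry when full.

```python
cache = StaticAttentionIOManager.NGramCache(max_size=2)
cache.add([5, 7])
cache.add([5, 7])   # duplicate, ignored
cache.add([9, 2])
cache.add([4, 4])   # cache is full, evicts the oldest entry [5, 7]
print(list(cache))  # [[9, 2], [4, 4]]
```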
177 | 198 | def __init__(
|
178 | 199 | self,
|
179 | 200 | config: ModelArgs,
|
@@ -266,12 +287,143 @@ def decode(
|
266 | 287 | new_tokens = [init_token]
|
267 | 288 | for _ in range(n):
|
268 | 289 | y = self._run_once(model, new_tokens[-1:])[0]
|
269 | | - new_tokens.append(y[:, :1, :].argmax().item()) |
| 290 | + new_tokens.append(y[:, :1, ...].argmax().item()) |
270 | 291 | if new_tokens[-1] in stop_tokens:
|
271 | 292 | break
|
272 | 293 |
|
273 | 294 | return new_tokens
|
274 | 295 |
|
| 296 | + def lookahead_decode( # noqa: C901 |
| 297 | + self, |
| 298 | + model: Callable[..., Any], |
| 299 | + init_token: int, |
| 300 | + n: int, |
| 301 | + ngram_size: int, |
| 302 | + window_size: int, |
| 303 | + n_verifications: int, |
| 304 | + stop_tokens: Optional[List[int]] = None, |
| 305 | + ngram_caches: Optional[Dict[int, "StaticAttentionIOManager.NGramCache"]] = None, |
| 306 | + ): |
| 307 | + if self.cache_full: |
| 308 | + raise RuntimeError("KV cache is full.") |
| 309 | + |
| 310 | + if (ngram_size - 1) * (window_size + n_verifications) > self.input_len: |
| 311 | + raise RuntimeError( |
| 312 | + "Lookahead decoding setting not compatible with input length." |
| 313 | + f" input_len = {self.input_len}," |
| 314 | + f" ngram_size = {ngram_size}," |
| 315 | + f" window_size = {window_size}," |
| 316 | + f" n_verifications = {n_verifications}" |
| 317 | + ) |
| 318 | + |
| 319 | + stop_tokens = stop_tokens or [] |
| 320 | + if ngram_caches is None: |
| 321 | + ngram_caches = defaultdict( |
| 322 | + lambda: StaticAttentionIOManager.NGramCache(n_verifications) |
| 323 | + ) |
| 324 | + |
| 325 | + self.mask.tensor[:, :, self.cache_len :] = self._get_lookahead_decoding_mask( |
| 326 | + ngram_size, window_size, n_verifications |
| 327 | + ) |
| 328 | + logger.debug("Lookahead decoding mask: ") |
| 329 | + for i in range(self.input_len): |
| 330 | + logger.debug( |
| 331 | + " ".join( |
| 332 | + ("X" if x == 0.0 else " ") |
| 333 | + for x in self.mask.tensor[0][i][self.cache_len :] |
| 334 | + ) |
| 335 | + ) |
| 336 | + |
| 337 | + pos_offsets = self._get_lookahead_position_offsets( |
| 338 | + ngram_size, window_size, n_verifications |
| 339 | + ) |
| 340 | + |
| 341 | + verification_offset = max(window_size * (ngram_size - 1), 1) |
| 342 | + new_tokens = [init_token] |
| 343 | + x = [init_token] * self.input_len |
| 344 | + inference_cnt = 0 |
| 345 | + while len(new_tokens) < n + 1: |
| 346 | + # Update verification branch with cached n-grams. |
| 347 | + cache = ngram_caches[x[0]] |
| 348 | + for i, ngram in enumerate(cache): |
| 349 | + for j, token in enumerate(ngram): |
| 350 | + x[verification_offset + i * (ngram_size - 1) + j] = token |
| 351 | + |
| 352 | + y, attn_updates = self._run_once( |
| 353 | + model, |
| 354 | + x, |
| 355 | + non_padded_len=1, |
| 356 | + freqs_cos_override=self.freqs_cos[pos_offsets + self.pos], |
| 357 | + freqs_sin_override=self.freqs_sin[pos_offsets + self.pos], |
| 358 | + ) |
| 359 | + inference_cnt += 1 |
| 360 | + # Only supports greedy decoding for now. |
| 361 | + y = y[0].argmax(dim=-1).tolist() |
| 362 | + new_tokens.append(y[0]) |
| 363 | + logger.debug(f"{self.pos}: x = {x[0]}, y = {y[0]}") |
| 364 | + if new_tokens[-1] in stop_tokens: |
| 365 | + break |
| 366 | + |
| 367 | + # Collect new n-grams. |
| 368 | + for i in range(window_size): |
| 369 | + key = x[i] |
| 370 | + suffix = [] |
| 371 | + for j in range(1, ngram_size - 1): |
| 372 | + suffix.append(x[i + j * window_size]) |
| 373 | + suffix.append(y[i + window_size * (ngram_size - 2)]) |
| 374 | + ngram_caches[key].add(suffix) |
| 375 | + |
| 376 | + # Verification. |
| 377 | + longest_match = [] |
| 378 | + matched_branch = None |
| 379 | + for i in range(n_verifications): |
| 380 | + match = [y[0]] |
| 381 | + j = 0 |
| 382 | + # Extend the match while the verification branch's input token agrees with the last accepted token. |
| 383 | + while ( |
| 384 | + j < ngram_size - 1 |
| 385 | + and x[verification_offset + (ngram_size - 1) * i + j] == match[-1] |
| 386 | + ): |
| 387 | + match.append(y[verification_offset + (ngram_size - 1) * i + j]) |
| 388 | + j += 1 |
| 389 | + if len(match) - 1 > len(longest_match): |
| 390 | + longest_match = match[1:] |
| 391 | + matched_branch = i |
| 392 | + |
| 393 | + if matched_branch is not None: |
| 394 | + logger.debug( |
| 395 | + f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}" |
| 396 | + ) |
| 397 | + for stop in stop_tokens: |
| 398 | + if stop in longest_match: |
| 399 | + longest_match = longest_match[: longest_match.index(stop) + 1] |
| 400 | + |
| 401 | + new_tokens.extend(longest_match) |
| 402 | + |
| 403 | + # Update KV caches and attention mask for the additional matched tokens. |
| 404 | + branch_offset = verification_offset + (ngram_size - 1) * matched_branch |
| 405 | + self._update_states( |
| 406 | + attn_updates, |
| 407 | + update_pos=branch_offset, |
| 408 | + update_len=len(longest_match), |
| 409 | + ) |
| 410 | + |
| 411 | + # Update lookahead branch. |
| 412 | + for i in range(ngram_size - 2): |
| 413 | + for j in range(window_size): |
| 414 | + x[window_size * i + j] = x[window_size * (i + 1) + j] |
| 415 | + for j in range(window_size): |
| 416 | + x[window_size * (ngram_size - 2) + j] = y[ |
| 417 | + window_size * (ngram_size - 2) + j |
| 418 | + ] |
| 419 | + |
| 420 | + x[0] = new_tokens[-1] |
| 421 | + |
| 422 | + logger.info( |
| 423 | + f"Generated {len(new_tokens) - 1} tokens with {inference_cnt} inference(s)." |
| 424 | + ) |
| 425 | + return new_tokens |
| 426 | + |
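A hypothetical usage sketch (names and values are illustrative and not defined in this diff): assuming `mgr` is an already-constructed `StaticAttentionIOManager` whose KV cache already holds the prompt, and `model` is the exported forward callable, lookahead decoding is invoked like the plain `decode` path but with the extra branch parameters. The guard above requires `(ngram_size - 1) * (window_size + n_verifications)` to fit within the manager's `input_len`.

```python
# Illustrative only; `mgr`, `model`, `prompt_tokens`, and `eos_id` are assumed
# to exist elsewhere. With ngram_size=3, window_size=4, n_verifications=4,
# the branches occupy (3 - 1) * (4 + 4) = 16 input positions, which must not
# exceed mgr.input_len.
tokens = mgr.lookahead_decode(
    model,
    init_token=prompt_tokens[-1],  # last prompt (or previously generated) token
    n=128,                         # generate at most 128 new tokens
    ngram_size=3,                  # length of lookahead/verification n-grams
    window_size=4,                 # parallel lookahead positions per level
    n_verifications=4,             # candidate n-grams verified per step
    stop_tokens=[eos_id],
)
```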
275 | 427 | def _run_once(
|
276 | 428 | self,
|
277 | 429 | model: Callable[..., Any],
|
@@ -330,6 +482,67 @@ def _update_states(self, attn_updates, update_pos, update_len):
|
330 | 482 | )
|
331 | 483 | self.pos += update_len
|
332 | 484 |
|
| 485 | + def _get_lookahead_decoding_mask( |
| 486 | + self, ngram_size: int, window_size: int, n_verifications: int |
| 487 | + ) -> torch.Tensor: |
| 488 | + mask = torch.full((self.input_len, self.input_len), self.mask_val) |
| 489 | + mask[0][0] = 0.0 |
| 490 | + |
| 491 | + lookahead_submask = torch.triu( |
| 492 | + torch.full((window_size, window_size), self.mask_val), |
| 493 | + diagonal=1, |
| 494 | + ) |
| 495 | + for i in range(ngram_size - 1): |
| 496 | + offset = window_size * i |
| 497 | + mask[offset : offset + window_size, :window_size] = lookahead_submask |
| 498 | + for j in range(1, i + 1): |
| 499 | + mask[ |
| 500 | + offset : offset + window_size, |
| 501 | + window_size * j : window_size * (j + 1), |
| 502 | + ].fill_diagonal_(0.0) |
| 503 | + |
| 504 | + verification_offset = max(window_size * (ngram_size - 1), 1) |
| 505 | + verification_submask = torch.triu( |
| 506 | + torch.full((ngram_size - 1, ngram_size - 1), self.mask_val), |
| 507 | + diagonal=1, |
| 508 | + ) |
| 509 | + for i in range(n_verifications): |
| 510 | + mask[ |
| 511 | + verification_offset |
| 512 | + + i * (ngram_size - 1) : verification_offset |
| 513 | + + (i + 1) * (ngram_size - 1), |
| 514 | + verification_offset |
| 515 | + + i * (ngram_size - 1) : verification_offset |
| 516 | + + (i + 1) * (ngram_size - 1), |
| 517 | + ] = verification_submask |
| 518 | + mask[verification_offset:, :1] = 0.0 |
| 519 | + |
| 520 | + return mask |
| 521 | + |
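To make the mask layout concrete, here is a standalone sketch that mirrors the construction above, under two assumptions not fixed by the diff: `mask_val` is `-inf`, and `input_len` equals `(ngram_size - 1) * (window_size + n_verifications)`, the minimum the guard in `lookahead_decode` allows. It prints the same X/blank visualization used in the debug logging.

```python
import torch

def lookahead_mask(ngram_size: int, window_size: int, n_verifications: int) -> torch.Tensor:
    mask_val = float("-inf")
    input_len = (ngram_size - 1) * (window_size + n_verifications)
    mask = torch.full((input_len, input_len), mask_val)
    mask[0][0] = 0.0

    # Lookahead branches: causal within the first window, plus diagonal links
    # to the same slot in every earlier level.
    lookahead_submask = torch.triu(
        torch.full((window_size, window_size), mask_val), diagonal=1
    )
    for i in range(ngram_size - 1):
        offset = window_size * i
        mask[offset : offset + window_size, :window_size] = lookahead_submask
        for j in range(1, i + 1):
            mask[
                offset : offset + window_size,
                window_size * j : window_size * (j + 1),
            ].fill_diagonal_(0.0)

    # Verification branches: each candidate n-gram is causal within itself
    # and can always attend to position 0.
    verification_offset = max(window_size * (ngram_size - 1), 1)
    verification_submask = torch.triu(
        torch.full((ngram_size - 1, ngram_size - 1), mask_val), diagonal=1
    )
    for i in range(n_verifications):
        lo = verification_offset + i * (ngram_size - 1)
        mask[lo : lo + ngram_size - 1, lo : lo + ngram_size - 1] = verification_submask
    mask[verification_offset:, :1] = 0.0
    return mask

for row in lookahead_mask(ngram_size=3, window_size=2, n_verifications=2):
    print(" ".join("X" if v == 0.0 else "." for v in row.tolist()))
```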
| 522 | + def _get_lookahead_position_offsets( |
| 523 | + self, ngram_size: int, window_size: int, n_verifications: int |
| 524 | + ) -> torch.Tensor: |
| 525 | + # Input position offsets, used for indexing RoPE frequencies. |
| 526 | + pos_offsets = torch.zeros(self.input_len, dtype=torch.int32) |
| 527 | + idx = 0 |
| 528 | + # Lookahead branches: [i + 0, i + 1, ..., i + window_size - 1] for time i. |
| 529 | + if window_size > 0: |
| 530 | + for i in range(ngram_size - 1): |
| 531 | + for j in range(window_size): |
| 532 | + pos_offsets[idx] = i + j |
| 533 | + idx += 1 |
| 534 | + else: |
| 535 | + pos_offsets[0] = 0 |
| 536 | + idx += 1 |
| 537 | + |
| 538 | + # Verification branches: [1, 2, ..., ngram_size - 1]. |
| 539 | + for _ in range(n_verifications): |
| 540 | + for j in range(1, ngram_size): |
| 541 | + pos_offsets[idx] = j |
| 542 | + idx += 1 |
| 543 | + |
| 544 | + return pos_offsets |
| 545 | + |
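As a worked example, with `ngram_size=3`, `window_size=2`, and `n_verifications=2`, this method returns offsets `[0, 1, 1, 2, 1, 2, 1, 2]`: positions `0, 1` and `1, 2` for the two lookahead levels, followed by `1, 2` for each of the two verification n-grams. These offsets are added to `self.pos` when indexing the RoPE frequency tables, so every branch is rotated as if it continued directly from the current decoding position.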
333 | 546 |
|
334 | 547 | class _Rope(nn.Module):
|
335 | 548 | def __init__(self, use_hf_rope):
|
|