import lightning as L
import torch
from typing import Optional
from lit_llama import LLaMA

@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed. Defaults to the smaller of the final sequence length and the model's block size.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating as soon as the <eos> token is emitted.
    """
    # determine the final sequence length and cap max_seq_length at the model's block size
    T = idx.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)
    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()
    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        # take the logits at the last position and apply temperature scaling
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        # sample from the resulting distribution
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
        # advance
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)
        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[: input_pos + 1]  # include the EOS token

    return idx
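
A minimal usage sketch follows, assuming a converted lit-llama checkpoint and tokenizer are already on disk; the paths and the plain torch.load call are illustrative placeholders rather than the repository's exact loading code.

from lit_llama import Tokenizer

# load the model weights (hypothetical paths; lit-llama's own scripts use a lazy-loading helper)
checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")
model = LLaMA.from_name("7B")
model.load_state_dict(checkpoint)
model.eval()

# encode a prompt, generate a continuation, and decode it back to text
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False)
output = generate(model, encoded, max_new_tokens=50, temperature=0.8, top_k=200)
print(tokenizer.decode(output))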