|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -from abc import ABC, abstractmethod |
8 | | -from typing import List, Optional, TypedDict |
| 7 | +from typing import List |
9 | 8 |
|
10 | 9 | import torch |
11 | | - |
12 | | -from executorch.extension.llm.tokenizer.utils import get_tokenizer |
13 | | -from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token, sample_top_p |
| 10 | +from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token |
14 | 11 |
|
15 | 12 |
|
16 | 13 | class TorchTuneLlamaRunner(LlamaRunner): |
@@ -54,13 +51,17 @@ def generate( # noqa: C901 |
54 | 51 | mask = self.causal_mask[None, :seq_len] |
55 | 52 | if self.use_kv_cache: |
56 | 53 | logits = self.forward( |
57 | | - tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), |
| 54 | + tokens=torch.tensor( |
| 55 | + [prompt_tokens], dtype=torch.long, device=self.device |
| 56 | + ), |
58 | 57 | input_pos=input_pos, |
59 | 58 | mask=mask, |
60 | 59 | ) |
61 | 60 | else: |
62 | 61 | logits = self.forward( |
63 | | - tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), |
| 62 | + tokens=torch.tensor( |
| 63 | + [prompt_tokens], dtype=torch.long, device=self.device |
| 64 | + ), |
64 | 65 | ) |
65 | 66 |
|
66 | 67 | # Only need the last logit. |
@@ -98,4 +99,3 @@ def generate( # noqa: C901 |
98 | 99 | seq_len += 1 |
99 | 100 |
|
100 | 101 | return tokens if echo else tokens[len(prompt_tokens) :] |
101 | | - |
0 commit comments