Commit c80ce1c

Remove has_full_logits from llama runner
1 parent: eeeeb8a

3 files changed: +2 −14 lines changed


examples/models/llama/runner/eager.py

Lines changed: 0 additions & 2 deletions
@@ -13,7 +13,6 @@
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
-    TORCHTUNE_DEFINED_MODELS,
 )
 from executorch.examples.models.llama.runner.generation import LlamaRunner
 from executorch.extension.llm.export.builder import LLMEdgeManager
@@ -33,7 +32,6 @@ def __init__(self, args):
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
             vocab_size=params["vocab_size"],
-            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
         manager: LLMEdgeManager = _prepare_for_llama_export(args)

examples/models/llama/runner/generation.py

Lines changed: 2 additions & 11 deletions
@@ -53,7 +53,6 @@ def __init__(
         max_batch_size: int,
         use_kv_cache: bool,
         vocab_size: int,
-        has_full_logits: bool = False,
         device: str = "cpu",
     ):
         """
@@ -64,14 +63,12 @@
             max_batch_size: max batch size.
             use_kv_cache: whether to use a KV cache.
             vocab_size: number of items in the vocab.
-            has_full_logits: whether the model returns the full logits or only returns the last logit.
             device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        self.has_full_logits = has_full_logits
         self.device = device
         assert vocab_size == self.tokenizer.n_words

@@ -102,10 +99,7 @@ def generate(  # noqa: C901
             ),
         )

-        if self.has_full_logits:
-            current_token = next_token(logits[:, -1, :], temperature, top_p)
-        else:
-            current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

@@ -127,10 +121,7 @@ def generate(  # noqa: C901
                 )

             # If the logits aren't already clipped to only contain the last logit, clip them.
-            if self.has_full_logits:
-                current_token = next_token(logits[:, -1, :], temperature, top_p)
-            else:
-                current_token = next_token(logits, temperature, top_p)
+            current_token = next_token(logits, temperature, top_p)
             tokens.append(current_token)

             if current_token == self.tokenizer.eos_id or (
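
For context, here is a minimal, self-contained sketch of the shape handling that the removed has_full_logits flag used to guard. The pick_next_token helper and the greedy argmax are illustrative assumptions, not the runner's actual next_token sampler (which also applies temperature and top-p); they only show why models that return full logits needed the logits[:, -1, :] clip before sampling.

import torch

def pick_next_token(logits: torch.Tensor) -> int:
    # Illustrative greedy sampling; the real runner calls next_token() with
    # temperature and top-p instead of a plain argmax.
    if logits.dim() == 3:
        # Full logits of shape [batch, seq_len, vocab]: clip to the last position.
        logits = logits[:, -1, :]
    # Last-position logits of shape [batch, vocab] can be sampled directly,
    # which is the only case the runner handles after this commit.
    return int(torch.argmax(logits[0]).item())

full = torch.randn(1, 5, 32000)   # [batch, seq_len, vocab]
last_only = full[:, -1, :]        # [batch, vocab]
assert pick_next_token(full) == pick_next_token(last_only)

With the branch removed, the runner always calls next_token(logits, temperature, top_p) directly and assumes the exported model already returns only the last position's logits.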

examples/models/llama/runner/native.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def __init__(self, args):
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
-            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
         )
         self.model = _load_for_executorch(args.pte)
