Skip to content

Commit aa289ea

Browse files
committed
Fix faulty merge
1 parent a0e33d9 commit aa289ea

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

examples/models/llama/runner/generation.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,25 @@ def __init__(
5353
max_batch_size: int,
5454
use_kv_cache: bool,
5555
vocab_size: int,
56+
has_full_logits: bool = False,
5657
device: str = "cpu",
5758
):
59+
"""
60+
Constructor.
61+
Args:
62+
tokenizer_path: path to tokenizer.model file.
63+
max_seq_len: max length of the output sequence, after which the output will be clipped.
64+
max_batch_size: max batch size.
65+
use_kv_cache: whether to use a KV cache.
66+
vocab_size: number of items in the vocab.
67+
has_full_logits: whether the model returns the full logits or only returns the last logit.
68+
device: device to run the runner on.
69+
"""
5870
self.max_seq_len = max_seq_len
5971
self.max_batch_size = max_batch_size
6072
self.use_kv_cache = use_kv_cache
6173
self.tokenizer = get_tokenizer(tokenizer_path)
74+
self.has_full_logits = has_full_logits
6275
self.device = device
6376
assert vocab_size == self.tokenizer.n_words
6477

0 commit comments

Comments (0)