@@ -53,7 +53,6 @@ def __init__(
         max_batch_size: int,
         use_kv_cache: bool,
         vocab_size: int,
-        has_full_logits: bool = False,
         device: str = "cpu",
     ):
         """
@@ -64,14 +63,12 @@ def __init__(
             max_batch_size: max batch size.
             use_kv_cache: whether to use a KV cache.
             vocab_size: number of items in the vocab.
-            has_full_logits: whether the model returns the full logits or only returns the last logit.
             device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        self.has_full_logits = has_full_logits
         self.device = device
         assert vocab_size == self.tokenizer.n_words

@@ -102,10 +99,7 @@ def generate( # noqa: C901
             ),
         )

-        if self.has_full_logits:
-            current_token = next_token(logits[:, -1, :], temperature, top_p)
-        else:
-            current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

@@ -127,10 +121,7 @@ def generate( # noqa: C901
         )

         # If the logits aren't already clipped to only contain the last logit, clip them.
-        if self.has_full_logits:
-            current_token = next_token(logits[:, -1, :], temperature, top_p)
-        else:
-            current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits, temperature, top_p)
         tokens.append(current_token)

         if current_token == self.tokenizer.eos_id or (
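For reference, after this change `next_token` is always handed logits that already correspond to the last position only, since the `has_full_logits` branch that sliced `logits[:, -1, :]` is gone. Below is a minimal sketch of what a temperature/top-p sampler with that shape might look like; it assumes 1-D logits of size `vocab_size` and standard nucleus sampling, and the helper name and signature simply mirror the call sites above. It is illustrative, not the repository's verified implementation:

```python
import torch

def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """Sample one token id from last-position logits of shape (vocab_size,)."""
    if temperature <= 0:
        # Greedy decoding when sampling is effectively disabled.
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    # Nucleus (top-p) filtering: keep the smallest set of highest-probability
    # tokens whose cumulative mass covers top_p, then renormalize.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx[choice].item())
```

Falling back to argmax at zero temperature avoids dividing by zero and is a common convention in generation runners.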