@@ -57,26 +57,12 @@ def __init__(
5757 max_batch_size : int ,
5858 use_kv_cache : bool ,
5959 vocab_size : int ,
60- has_full_logits : bool = False ,
6160 device : str = "cpu" ,
6261 ):
63- """
64- Constructor.
65-
66- Args:
67- tokenizer_path: path to tokenizer.model file.
68- max_seq_len: max length of the output sequence, after which the output will be clipped.
69- max_batch_size: max batch size.
70- use_kv_cache: whether to use a KV cache.
71- vocab_size: number of items in the vocab.
72- has_full_logits: whether the model returns the full logits or only returns the last logit.
73- device: device to run the runner on.
74- """
7562 self .max_seq_len = max_seq_len
7663 self .max_batch_size = max_batch_size
7764 self .use_kv_cache = use_kv_cache
7865 self .tokenizer = get_tokenizer (tokenizer_path )
79- self .has_full_logits = has_full_logits
8066 self .device = device
8167 assert vocab_size == self .tokenizer .n_words
8268
@@ -95,7 +81,7 @@ def generate( # noqa: C901
9581 top_p : float = 0.9 ,
9682 echo : bool = False ,
9783 ) -> List [int ]:
98- # prefill
84+ # Prefill
9985 logits = self .forward (
10086 tokens = torch .tensor ([prompt_tokens ], dtype = torch .long , device = self .device ),
10187 input_pos = (
@@ -105,11 +91,7 @@ def generate( # noqa: C901
10591 ),
10692 )
10793
108- current_token = next_token (logits [:, - 1 , :], temperature , top_p )
109- if self .has_full_logits :
110- current_token = next_token (logits [:, - 1 , :], temperature , top_p )
111- else :
112- current_token = next_token (logits , temperature , top_p )
94+ current_token = next_token (logits , temperature , top_p )
11395 tokens = prompt_tokens + [current_token ]
11496
11597 i = 0
@@ -129,12 +111,7 @@ def generate( # noqa: C901
129111 tokens = torch .tensor ([tokens ], dtype = torch .long , device = self .device ),
130112 )
131113
132- # If the logits aren't already clipped to only contain the last logit, clip them.
133- if self .has_full_logits :
134- current_token = next_token (logits [:, - 1 , :], temperature , top_p )
135- else :
136- current_token = next_token (logits , temperature , top_p )
137-
114+ current_token = next_token (logits , temperature , top_p )
138115 if current_token == self .tokenizer .eos_id or (
139116 hasattr (self .tokenizer , "stop_tokens" )
140117 and current_token in self .tokenizer .stop_tokens
0 commit comments