def chat_completion(
    self,
    max_seq_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    show_progress: bool = False,
) -> List[int]:
    """
    Run an interactive multi-turn chat loop with the language model.

    Repeatedly reads a user prompt from stdin ("Me: "), generates a model
    reply, and accumulates the full token history of the conversation.
    The loop ends when the user enters an empty line or the literal
    string "exit".

    Args:
        max_seq_len (int): Maximum number of tokens to generate for each prompt.
        temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
        top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
        show_progress (bool, optional): Flag indicating whether to show number of tokens generated.

    Returns:
        Generated list of tokens.
    """
    exit_prompt = "exit"
    tokens: List[int] = []
    # Last token of the previous reply; re-fed so generation resumes
    # seamlessly from where the prior turn stopped.
    pre_stop_token: List[int] = []

    prompt = input("Me: ")
    while prompt and prompt != exit_prompt:
        print("LLM: ", end="", flush=True)

        encoded_prompt = self.tokenizer.encode(
            self._format_prompt(prompt), bos=True, eos=False
        )
        reply_tokens = self.generate(
            prompt_tokens=pre_stop_token + encoded_prompt,
            max_seq_len=max_seq_len,
            temperature=temperature,
            top_p=top_p,
            echo=False,
            # Continue positions from the existing history (0 on first turn).
            pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
        )

        # Remember the stop token for the next turn, then record this
        # turn's prompt and reply in the running history.
        pre_stop_token = reply_tokens[-1:]
        tokens.extend(encoded_prompt)
        tokens.extend(reply_tokens)

        if show_progress:
            print(f"[Generated {len(tokens)} tokens]")

        prompt = input("Me: ")

    return tokens
207214
0 commit comments