@@ -67,12 +67,13 @@ def generate(  # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@ def generate(  # noqa: C901
                     [[current_token]], dtype=torch.long, device=self.device
                 ),
                 input_pos=torch.tensor(
-                    [len(tokens) - 1], dtype=torch.long, device=self.device
+                    [pos_base + len(tokens) - 1],
+                    dtype=torch.long,
+                    device=self.device,
                 ),
             )
         else:
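The point of the pos_base change above: a later generate() call no longer resets the KV cache to position 0, since every input_pos is offset by however many tokens the caller has already cached. A minimal, framework-free sketch of that arithmetic (decode_positions is a hypothetical helper written only to mirror the diff's math; it is not part of the runner):

# Hypothetical helper mirroring the diff's input_pos arithmetic: prefill
# passes [pos_base]; each decode call passes [pos_base + len(tokens) - 1],
# where `tokens` is the prompt plus everything sampled so far.
def decode_positions(num_prompt: int, num_generated: int, pos_base: int = 0):
    tokens = list(range(num_prompt))  # stand-ins for prompt token ids
    calls = [pos_base]                # the prefill call's input_pos
    for _ in range(num_generated):
        tokens.append(0)              # stand-in for a freshly sampled token
        calls.append(pos_base + len(tokens) - 1)
    return calls

# Turn 1: 4 prompt tokens, 3 sampled -> input_pos values [0, 4, 5, 6]
print(decode_positions(4, 3))
# Turn 2 continues the same cache after those 7 tokens -> [7, 12, 13]
print(decode_positions(5, 2, pos_base=7))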
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model.
+
+        Args:
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+
+        Returns:
+            List of all tokens generated across the conversation.
+
+        Note:
+            This method reads user turns from stdin until "exit" (or an empty
+            line) is entered, generating a reply to each turn and threading
+            pos_base so every turn reuses the KV cache of the turns before it.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
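For reference, _format_prompt emits the Llama 3 chat template: a system turn, the user's message, then an open assistant header for the model to complete. Here is a standalone copy you can run to inspect the exact string it produces (it reproduces the diff's leading newline before <|begin_of_text|>):

# Standalone copy of the template above; runnable without the runner class.
def format_prompt(prompt: str) -> str:
    return f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

print(format_prompt("What is the capital of France?"))

Note that chat_completion re-encodes this full template with bos=True on every turn, so earlier turns persist only through the KV cache (via pos_base), not in the text of later prompts.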