 import torch.nn.functional as F
 from executorch.examples.models.llama2.llama_transformer import ModelArgs

-from executorch.examples.models.llama2.tokenizer.tiktoken import (
-    Dialog,
-    Message,
-    Tokenizer,
-)
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer
 from executorch.extension.pybindings.portable_lib import _load_for_executorch


@@ -28,12 +24,6 @@ class CompletionPrediction(TypedDict, total=False):
     logprobs: List[float]  # not required


-class ChatPrediction(TypedDict, total=False):
-    generation: Message
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -225,72 +215,6 @@ def text_completion(
         ]
         return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

-    def chat_completion(
-        self,
-        dialogs: List[Dialog],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-    ) -> List[ChatPrediction]:
-        """
-        Generate assistant responses for a list of conversational dialogs using the language generation model.
-
-        Args:
-            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
-            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
-            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
-                If not provided, it's set to the model's maximum sequence length minus 1.
-            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
-
-        Returns:
-            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
-
-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
-        Note:
-            This method generates assistant responses for the provided conversational dialogs.
-            It employs nucleus sampling to introduce controlled randomness in text generation.
-            If logprobs is True, token log probabilities are computed for each generated token.
-        """
-        if max_gen_len is None:
-            max_gen_len = self.model.params.max_seq_len - 1
-
-        prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
-        ]
-        generation_tokens, generation_logprobs = self.generate(
-            prompt_tokens=prompt_tokens,
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-        )
-        if logprobs:
-            return [
-                {
-                    "generation": {
-                        "role": "assistant",
-                        "content": self.tokenizer.decode(t),
-                    },
-                    "tokens": [self.tokenizer.decode([x]) for x in t],
-                    "logprobs": logprobs_i,
-                }
-                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
-            ]
-        return [
-            {
-                "generation": {
-                    "role": "assistant",
-                    "content": self.tokenizer.decode(t),
-                },
-            }
-            for t in generation_tokens
-        ]
-

 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
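Note on the retained `sample_top_p` helper: top-p (nucleus) sampling sorts the distribution, keeps the smallest prefix of tokens whose cumulative probability exceeds `p`, renormalizes that prefix, and samples from it. The helper's actual body lies outside the visible hunks, so the sketch below is only the standard PyTorch formulation of the algorithm; the name `sample_top_p_sketch` and the exact implementation are assumptions for illustration.

```python
import torch


def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Illustrative sketch, not necessarily the body of this file's sample_top_p.
    # Sort token probabilities in descending order, remembering original ids.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # Cumulative mass over the sorted distribution.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop every token where the mass *before* it already exceeds p, so the
    # kept prefix is the smallest set whose cumulative probability exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities and sample one token per row.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary id.
    return torch.gather(probs_idx, -1, next_token)
```

For a `(batch, vocab)` logits tensor this would be driven as `sample_top_p_sketch(F.softmax(logits / temperature, dim=-1), p=0.9)`, matching the `temperature`/`top_p` parameters seen in the removed `chat_completion` signature above.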