
import torch

-from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.extension.llm.tokenizer.utils import get_tokenizer


@@ -51,11 +50,35 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:


class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
-        self.params = model_args
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+        has_full_logits: bool = False,
+        device: str = "cpu",
+    ):
+        """
+        Constructor.
+
+        Args:
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            has_full_logits: whether the model returns the full logits or only returns the last logit.
+            device: device to run the runner on.
+        """
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+        self.use_kv_cache = use_kv_cache
        self.tokenizer = get_tokenizer(tokenizer_path)
-        assert model_args.vocab_size == self.tokenizer.n_words
+        self.has_full_logits = has_full_logits
        self.device = device
+        assert vocab_size == self.tokenizer.n_words

    @abstractmethod
    def forward(
@@ -77,16 +100,22 @@ def generate( # noqa: C901
            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
            input_pos=(
                torch.tensor([0], dtype=torch.long, device=self.device)
-                if self.params.use_kv_cache
+                if self.use_kv_cache
                else None
            ),
        )

-        current_token = next_token(logits, temperature, top_p)
+        if self.has_full_logits:
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+        else:
+            current_token = next_token(logits, temperature, top_p)
        tokens = prompt_tokens + [current_token]

-        while len(tokens) < self.params.max_seq_len:
-            if self.params.use_kv_cache:
+        i = 0
+        while len(tokens) < self.max_seq_len:
+            print(f"{i} out of {self.max_seq_len} max tokens generated")
+            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
@@ -99,13 +128,21 @@ def generate( # noqa: C901
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )
-            current_token = next_token(logits, temperature, top_p)
+
+            # If the logits aren't already clipped to only contain the last logit, clip them.
+            if self.has_full_logits:
+                current_token = next_token(logits[:, -1, :], temperature, top_p)
+            else:
+                current_token = next_token(logits, temperature, top_p)
+
            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break
+
            tokens.append(current_token)
+            i += 1

        return tokens if echo else tokens[len(prompt_tokens) :]

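For context on how the refactored constructor is meant to be called: a subclass now passes the model metadata explicitly instead of handing over a ModelArgs object. The sketch below is a minimal, illustrative example only; the ToyRunner class, the wrapped nn.Module, the import path, and the parameter values are assumptions and are not part of this PR.

# Minimal usage sketch (assumed names and import path; not code from this PR).
from typing import Optional

import torch

from executorch.examples.models.llama.runner.generation import LlamaRunner  # path assumed


class ToyRunner(LlamaRunner):
    def __init__(self, model: torch.nn.Module, tokenizer_path: str, vocab_size: int):
        # ModelArgs is no longer required; each field is passed explicitly.
        super().__init__(
            tokenizer_path=tokenizer_path,
            max_seq_len=128,
            max_batch_size=1,
            use_kv_cache=False,
            vocab_size=vocab_size,
            # This hypothetical eager model returns logits for every position,
            # so generate() slices out the last position via logits[:, -1, :].
            has_full_logits=True,
            device="cpu",
        )
        self.model = model

    def forward(
        self, tokens: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # With use_kv_cache=False, generate() re-feeds the whole sequence each
        # step and passes input_pos=None, so it can be ignored here.
        return self.model(tokens)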