@@ -31,6 +31,7 @@ def generate(
     prompt: torch.Tensor,
     max_returned_tokens: int,
     *,
+    prompt_chunksize: int = 1,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: float = 1.0,
@@ -60,35 +61,60 @@ def generate(
             or https://huyenchip.com/2024/01/16/sampling.html#top_p
         stop_tokens: If specified, stop generating any more token once one of this list is generated.
     """
-    from litgpt.generate.base import generate_fn
-    return generate_fn(
-        include_prompt=False,
-        include_eos=False,
-        model=model,
-        prompt=prompt,
-        max_returned_tokens=max_returned_tokens,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        stop_tokens=stop_tokens
+    from litgpt.generate.base import batched_generate_fn
+
+    return map(
+        lambda lst: lst[0],
+        batched_generate_fn(
+            model=model,
+            prompts=[prompt],
+            max_returned_tokens=max_returned_tokens,
+            prompt_chunksize=prompt_chunksize,
+            sample_args=dict(
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            ),
+            stop_tokens=stop_tokens,
+            include_prompt=False,
+            include_eos=False,
+        )
     )
 
 
-def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def process_prompt(
+    prompt: str,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     prompt = prompt_style.apply(prompt=prompt)
     encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
 
     if max_new_tokens is None:
         max_returned_tokens = model.max_seq_length
     else:
-        first_turn = model.mask_cache is None
         max_returned_tokens = encoded_prompt.size(0) + max_new_tokens
-        if first_turn or max_returned_tokens > model.max_seq_length:
+        msl = model.max_seq_length
+        if max_returned_tokens > msl or model.config.block_size == msl:
             model.max_seq_length = max_returned_tokens
-            model.set_kv_cache(batch_size=1, device=fabric.device)
 
     y: Iterator[torch.Tensor] = generate(
-        model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
+        model=model,
+        prompt=encoded_prompt,
+        max_returned_tokens=max_returned_tokens,
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        stop_tokens=stop_tokens,
     )
     token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device)
 
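Note: `generate` now returns a lazy `map` over `batched_generate_fn` applied to a single-element batch, so each yielded item is the sole entry of that step's batch. A minimal sketch of calling it with the new `prompt_chunksize` argument, assuming `model`, `tokenizer`, and `fabric` are already set up as in `main()`; the chunk size, sampling values, and EOS stop token below are illustrative and not part of this diff:

```python
# Sketch only: `model`, `tokenizer`, and `fabric` are assumed to be loaded already.
encoded_prompt = tokenizer.encode("What is a KV cache?", device=fabric.device)

stream = generate(
    model=model,
    prompt=encoded_prompt,
    max_returned_tokens=encoded_prompt.size(0) + 64,
    prompt_chunksize=16,                # prefill the prompt 16 tokens at a time
    temperature=0.8,
    top_k=50,
    stop_tokens=([tokenizer.eos_id],),  # stop once EOS is sampled
)

# As in process_prompt() below, the token stream can be decoded incrementally:
for piece in tokenizer.decode_stream(stream, device=fabric.device):
    print(piece, end="", flush=True)
```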
@@ -103,8 +129,7 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
 
     t = time.perf_counter() - t0
 
-    for block in model.transformer.h:
-        block.attn.kv_cache.reset_parameters()
+    model.clear_kv_cache()
     fabric.print(
         f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
         f" {tokens_generated} tokens",
@@ -113,7 +138,19 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
     fabric.print()
 
 
-def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def interact(
+    multiline: bool,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     while True:
         try:
             if not multiline:
@@ -135,14 +172,27 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max
         if not prompt or prompt in ("!quit", "!exit"):
             break
 
-        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens)
+        process_prompt(
+            prompt=prompt,
+            model=model,
+            tokenizer=tokenizer,
+            prompt_style=prompt_style,
+            fabric=fabric,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            prompt_chunksize=prompt_chunksize,
+            top_k=top_k,
+            top_p=top_p,
+            stop_tokens=stop_tokens,
+        )
 
 
 @torch.inference_mode()
 def main(
     checkpoint_dir: Path,
     *,
     max_new_tokens: int = 50,
+    prompt_chunksize: int = 1,
     top_k: Optional[int] = 50,
     top_p: float = 1.0,
     temperature: float = 0.8,
@@ -158,6 +208,11 @@ def main(
         checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
             You can get a list of valid model names via the `litgpt download list` command line argument.
         max_new_tokens: The number of generation steps to take.
+        prompt_chunksize: If even the shortest prompt is longer than the KV
+            cache, prompts are processed in chunks of this size in the
+            prefill phase. Once the shortest has been processed to the
+            end, we proceed with chunk size 1.
+            Defaults to 1, but larger values are recommended for long prompts.
         top_k: The number of top most probable tokens to consider in the sampling process.
         top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
             In top-p sampling, the next token is sampled from the highest probability tokens
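For intuition, the chunked prefill described by `prompt_chunksize` can be pictured with the standalone sketch below. This is not the litgpt implementation (that lives in `batched_generate_fn`); the helper name `chunked_prefill` is hypothetical, and the `model(chunk, input_pos=...)` call assumes a litgpt-style `GPT.forward` with a preallocated KV cache:

```python
import torch

def chunked_prefill(model, prompt: torch.Tensor, prompt_chunksize: int) -> torch.Tensor:
    """Hypothetical helper: run the prefill phase over `prompt` in chunks.

    Each forward pass writes the chunk's keys/values into the model's KV cache,
    so peak activation memory scales with the chunk size rather than with the
    full prompt length.
    """
    logits = None
    pos = 0
    while pos < prompt.size(0):
        chunk = prompt[pos : pos + prompt_chunksize]
        input_pos = torch.arange(pos, pos + chunk.size(0), device=prompt.device)
        logits = model(chunk.unsqueeze(0), input_pos=input_pos)
        pos += chunk.size(0)
    # The logits at the last position of the final chunk are what the sampler
    # uses to draw the first newly generated token.
    return logits
```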
@@ -252,8 +307,9 @@ def main(
         tokenizer=tokenizer,
         prompt_style=prompt_style,
         fabric=fabric,
-        temperature=temperature,
         max_new_tokens=(None if compile else max_new_tokens),
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         stop_tokens=stop_tokens