@@ -33,329 +33,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import Transformer, find_multiple
 from tokenizer import get_tokenizer
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    logits = logits / max(temperature, 1e-5)
-
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-
-def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> torch.Tensor:
-    # input_pos: [B, S]
-    causal_mask = (
-        torch.tril(torch.ones(len(input_pos), len(input_pos), dtype=torch.bool))
-        .unsqueeze(0)
-        .unsqueeze(0)
-        .to(x.device)
-    )
-    logits = model(x, input_pos, mask=causal_mask)
-    return sample(logits, **sampling_kwargs)[0]
-
-
-def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
-    assert input_pos.shape[-1] == 1
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
-
-
-def decode_n_tokens(
-    model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    terminator_ids: Optional[list] = None,
-    callback=lambda _: _,
-    **sampling_kwargs,
-):
-    new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
-        ):  # Actually better for Inductor to codegen attention here
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-
-        if terminator_ids and next_token in terminator_ids:
-            break
-
-        input_pos += 1
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
-        new_probs.append(next_prob.clone())
-        cur_token = next_token.view(1, -1)
-
-    return new_tokens, new_probs
-
-
-def model_forward(model, x, input_pos):
-    return model(x, input_pos)
-
-
-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor(
-        [input_pos], dtype=torch.int64, device=cur_token.device
-    )
-    draft_tokens, draft_probs = decode_n_tokens(
-        draft_model,
-        cur_token.view(1, -1),
-        orig_input_pos.clone(),
-        speculate_k,
-        **sampling_kwargs,
-    )
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device),
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
-    rejected_locations = (
-        torch.rand_like(accept_draft_prob) > accept_draft_prob
-    ).nonzero()
-
-    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
-
-def normalize_cache_length(
-    max_cache_length: float, max_seq_length: int, multiple_of: int = 8
-) -> int:
-    """
-    Computes the absolute cache length given the max_cache_length and max_seq_length.
-    """
-    if 0 < max_cache_length <= 1:
-        max_cache_length = round(max_seq_length * max_cache_length)
-    else:
-        assert int(max_cache_length) == max_cache_length
-        max_cache_length = int(max_cache_length)
-        if max_cache_length > max_seq_length:
-            print(
-                f"Warning: max_cache_length ({max_cache_length}) is greater than max_seq_length ({max_seq_length}). Setting to {max_seq_length}"
-            )
-            max_cache_length = max_seq_length
-    return min(find_multiple(max_cache_length, multiple_of), max_seq_length)
-
-
-@torch.no_grad()
-def generate(
-    model: Transformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    *,
-    interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
-    callback=lambda x: x,
-    terminator_ids: Optional[list] = None,
-    cache_kwargs: dict = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    is_speculative = draft_model is not None
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
-    max_seq_length = min(T + max_new_tokens, model.config.block_size)
-    if interactive:
-        max_seq_length = 350
-    print(f"Maximum context length of {max_seq_length} tokens.")
-
-    max_new_tokens = max_seq_length - T
-
-    device, dtype = prompt.device, prompt.dtype
-    max_seq_length = (
-        max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
-    )
-
-    # Normalize max_cache_length to absolute cache length if provided as a fraction of the max seq sequence length
-    cache_kwargs["max_cache_length"] = list(
-        map(
-            lambda l: normalize_cache_length(l, max_seq_length),
-            cache_kwargs["max_cache_length"],
-        )
-    )
-    assert (
-        model.config.n_layer % len(cache_kwargs["max_cache_length"]) == 0
-    ), f'max_cache_length ({len(cache_kwargs["max_cache_length"])}) must be a factor of {model.config.n_layer} layers.'
-
-    tile_size = model.config.n_layer // len(cache_kwargs["max_cache_length"])
-    cache_kwargs["max_cache_length"] = [
-        item for item in cache_kwargs["max_cache_length"] for _ in range(tile_size)
-    ]
-
-    # Gets called twice when model is wrapped in torch.compile which causes an error without the if statement
-    if type(cache_kwargs["drop_amount"]) != list:
-        cache_kwargs["drop_amount"] = [
-            max(int(cache_kwargs["drop_amount"] * l), 1)
-            for l in cache_kwargs["max_cache_length"]
-        ]
-
-    assert cache_kwargs["global_tokens"] <= min(
-        cache_kwargs["max_cache_length"]
-    ), "Global tokens must be less than max_cache_length."
-
-    with torch.device(device):
-        model.setup_caches(max_batch_size=1, **cache_kwargs)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, **cache_kwargs)
-
-    # create an empty tensor (all -1) of the expected final shape and fill in the current tokens
-    # GPT-Fast had this as empty but the values of empty are non-deterministic
-    seq = torch.full((max_seq_length,), -1, dtype=dtype, device=device)
-    seq[:T] = prompt
-    input_pos = torch.arange(0, T, device=device)
-
-    next_token = prefill(
-        model, prompt.view(1, -1), input_pos, **sampling_kwargs
-    ).clone()
-    if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
-
-    if is_speculative:
-        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-        while input_pos < max_seq_length - 1:
-            cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )
-
-            accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(max_seq_length - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
-            for i in next_tokens[:num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
-            next_token = next_tokens[-1]
-    else:
-        generated_tokens, _ = decode_n_tokens(
-            model,
-            next_token.view(1, -1),
-            input_pos,
-            max_new_tokens - 1,
-            callback=callback,
-            terminator_ids=terminator_ids,
-            **sampling_kwargs,
-        )
-        if len(generated_tokens) > 0:
-            seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
-
-    # Truncate seq to first instance of -1 if -1 is present
-    if -1 in seq:
-        seq = seq[: torch.where(seq == -1)[0][0]]
-
-    generate_stats = {"accept_counts": accept_counts}
-    return seq, generate_stats
-
-
-def encode_tokens(tokenizer, string, bos=True, device=default_device):
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
-
-
-def _load_model(checkpoint_path, device, precision, use_tp):
-    use_cuda = "cuda" in device
-    with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
-
-    if "int8" in str(checkpoint_path):
-        print("Using int8 weight-only quantization!")
-        from quantize import WeightOnlyInt8QuantHandler
-
-        simple_quantizer = WeightOnlyInt8QuantHandler(model)
-        model = simple_quantizer.convert_for_runtime()
-
-    if "int4" in str(checkpoint_path):
-        print("Using int4 weight-only quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
-    model.load_state_dict(checkpoint, assign=True)
-    if use_tp:
-        from tp import apply_tp
-
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    return model.eval()
+from generation_utils import generate, encode_tokens, _load_model
 
 
 def _get_model_size(model):
@@ -513,7 +192,7 @@ def callback(x):
             torch.profiler._utils._init_for_cuda_graphs()
             prof = torch.profiler.profile()
         with prof:
-            y, metrics = generate(
+            y, metrics, _ = generate(
                 model,
                 encoded,
                 max_new_tokens,
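For context, below is a minimal usage sketch of the relocated helpers, based only on what this diff shows: `generate`, `encode_tokens`, and `_load_model` now live in `generation_utils`, and the updated call site unpacks three return values from `generate`. The checkpoint path, tokenizer arguments, sampling settings, and `cache_kwargs` values are hypothetical placeholders; the actual signatures in `generation_utils` may differ from the removed code above.

```python
# Illustrative sketch only (not part of the commit). Assumes the relocated helpers
# keep interfaces close to the removed code, and that generate() now returns a
# third value, which the updated call site discards.
from pathlib import Path

import torch

from generation_utils import generate, encode_tokens, _load_model
from tokenizer import get_tokenizer

checkpoint_path = Path("checkpoints/my-model/model.pth")  # hypothetical path
device, precision = "cuda", torch.bfloat16

# Load the model onto the target device at the chosen precision.
model = _load_model(checkpoint_path, device, precision, use_tp=False)

# Tokenizer arguments follow gpt-fast's convention; adjust to this repo's tokenizer.py.
tokenizer = get_tokenizer(checkpoint_path.parent / "tokenizer.model", checkpoint_path)

# Encode the prompt (prepends BOS) as an int tensor on the target device.
encoded = encode_tokens(tokenizer, "Hello, my name is", bos=True, device=device)

# Keys mirror those read by the removed generate() above; values are arbitrary examples.
cache_kwargs = {
    "max_cache_length": [1.0],  # fraction of max_seq_length, normalized internally
    "drop_amount": 0.1,
    "global_tokens": 4,
}

# The third return value is discarded, matching the updated call site in this diff.
y, metrics, _ = generate(
    model,
    encoded,
    max_new_tokens=128,
    interactive=False,
    draft_model=None,       # no speculative decoding in this sketch
    cache_kwargs=cache_kwargs,
    temperature=0.8,        # forwarded via **sampling_kwargs
    top_k=200,
)
print(tokenizer.decode(y.tolist()))
```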