@@ -13,9 +13,18 @@
 import torch._dynamo.config
 import torch._inductor.config
 
+def device_sync(device):
+    if "cuda" in device:
+        torch.cuda.synchronize()
+    elif "cpu" in device:
+        pass
+    else:
+        print(f"device={device} is not yet supported")
+
+
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
 
 # support running without installing as a package
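Note on the new `device_sync` helper: CUDA kernel launches are asynchronous, so the timing code later in this file only measures real work if it drains the device first, while CPU ops complete eagerly and need no barrier. A minimal sketch of the pattern; the `timed` wrapper is hypothetical and not part of this diff:

```python
import time

import torch

def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize()  # block until all queued CUDA kernels finish
    elif "cpu" in device:
        pass  # CPU ops run synchronously; nothing to drain
    else:
        print(f"device={device} is not yet supported")

def timed(fn, device="cuda"):
    # Hypothetical helper: sync before starting and before reading the
    # clock so the interval covers exactly the work submitted by fn().
    device_sync(device=device)
    t0 = time.perf_counter()
    out = fn()
    device_sync(device=device)
    return out, time.perf_counter() - t0
```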
@@ -58,18 +67,30 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)
 
-def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_sdpa=False, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+    if not use_sdpa:
+        for i in range(num_new_tokens):
+            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+                next_token, next_prob = decode_one_token(
+                    model, cur_token, input_pos, **sampling_kwargs
+                )
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            callback(new_tokens[-1])
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, -1)
+    else:
+        for i in range(num_new_tokens):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
-        input_pos += 1
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
-        new_probs.append(next_prob.clone())
-        cur_token = next_token.view(1, -1)
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            callback(new_tokens[-1])
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, -1)
+
     return new_tokens, new_probs
 
 
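For context on the branch above: `torch.backends.cuda.sdp_kernel` is a PyTorch context manager that restricts which backends `torch.nn.functional.scaled_dot_product_attention` may dispatch to. With only the math backend enabled, Inductor can trace through the attention and generate its own fused kernel; with `--use_sdpa`, dispatch is left free to pick flash or memory-efficient attention. A minimal sketch, assuming a CUDA device and illustrative shapes:

```python
import torch
import torch.nn.functional as F

# One decode step: a single query position attending over 128 cached positions.
q = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Math backend only, as in the not-use_sdpa path above.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out_math = F.scaled_dot_product_attention(q, k, v)

# Default dispatch, as in the use_sdpa path: flash / mem-efficient allowed.
out_sdpa = F.scaled_dot_product_attention(q, k, v)
```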
@@ -82,12 +103,13 @@ def speculative_decode(
     cur_token: torch.Tensor,
     input_pos: int,
     speculate_k: int,
+    use_sdpa=False,
     **sampling_kwargs
 ) -> torch.Tensor:
     # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, use_sdpa=use_sdpa, **sampling_kwargs)
 
     draft_tokens = torch.cat(draft_tokens)
     # parallel inference on target model using draft tokens
@@ -135,6 +157,7 @@ def generate(
     interactive: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
+    use_sdpa=False,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -178,7 +201,7 @@ def generate(
             cur_token = next_token.view(())
 
             next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
+                model, draft_model, cur_token, input_pos, speculate_k, use_sdpa=use_sdpa, **sampling_kwargs
             )
 
             accept_counts[len(next_tokens) - 1] += 1
@@ -189,7 +212,7 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, use_sdpa=use_sdpa, callback=callback, **sampling_kwargs)
         seq[T + 1:] = torch.cat(generated_tokens)
 
     generate_stats = {
@@ -248,6 +271,8 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
+    use_sdpa=False,
+    device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
@@ -265,7 +290,7 @@ def main(
         # only print on rank 0
         print = lambda *args, **kwargs: None
 
-    device = 'cuda'
+    print(f"Using device={device}")
     precision = torch.bfloat16
     is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
@@ -279,7 +304,7 @@ def main(
     else:
         draft_model = None
 
-    torch.cuda.synchronize()
+    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
@@ -289,8 +314,9 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp:
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
+        # MKG
+        # if is_speculative and use_tp:
+        #     torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
 
         if is_speculative:
             global model_forward, logits_to_prob
@@ -311,7 +337,7 @@ def main(
     start = -1 if compile else 0
 
     for i in range(start, num_samples):
-        torch.cuda.synchronize()
+        device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
             if is_chat:
@@ -349,6 +375,7 @@ def callback(x):
             max_new_tokens,
             draft_model=draft_model,
             speculate_k=speculate_k,
+            use_sdpa=use_sdpa,
             interactive=interactive,
             callback=callback,
             temperature=temperature,
@@ -363,7 +390,7 @@ def callback(x):
                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
             else:
                 prof.export_chrome_trace(f"{profile}.json")
-        torch.cuda.synchronize()
+        device_sync(device=device) # MKG
         t = time.perf_counter() - t0
 
         if not interactive:
@@ -402,9 +429,12 @@ def callback(x):
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
+    parser.add_argument('--use_sdpa', action='store_true', help='Whether to use SDPA.')
+    parser.add_argument('--device', type=str, default="cuda", help='Device to use.')

    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, args.speculate_k
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
+        args.speculate_k, args.use_sdpa, args.device
    )
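With both flags wired through, a CPU run that lets SDPA dispatch freely would look something like `python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --device cpu --use_sdpa` (invocation illustrative; the checkpoint path follows the repository's usual layout).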