13
13
import torch ._dynamo .config
14
14
import torch ._inductor .config
15
15
16
def device_sync(device):
    """Wait for outstanding work on ``device`` before continuing.

    Args:
        device: Device identifier such as ``"cuda"``, ``"cuda:0"`` or
            ``"cpu"``. A ``torch.device`` is accepted too; it is matched
            by its string form.

    Returns:
        None.

    For a CUDA device this calls ``torch.cuda.synchronize()`` with no
    argument. For a CPU device nothing is done. Any other device only
    prints a notice and returns without synchronizing.
    """
    # The substring checks need a str; `"cuda" in torch.device(...)` would
    # raise TypeError, so normalise first.
    device_name = str(device)
    if "cuda" in device_name:
        torch.cuda.synchronize()
    elif "cpu" in device_name:
        pass  # no synchronization call is made for CPU
    else:
        # Best-effort: report and carry on rather than abort the run.
        print(f"device={device} is not yet supported")
23
+
24
+
16
25
torch ._inductor .config .coordinate_descent_tuning = True
17
26
torch ._inductor .config .triton .unique_kernel_names = True
18
27
torch ._inductor .config .fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
@@ -65,11 +74,12 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
65
74
next_token , next_prob = decode_one_token (
66
75
model , cur_token , input_pos , ** sampling_kwargs
67
76
)
68
- input_pos += 1
69
- new_tokens .append (next_token .clone ())
70
- callback (new_tokens [- 1 ])
71
- new_probs .append (next_prob .clone ())
72
- cur_token = next_token .view (1 , - 1 )
77
+ input_pos += 1
78
+ new_tokens .append (next_token .clone ())
79
+ callback (new_tokens [- 1 ])
80
+ new_probs .append (next_prob .clone ())
81
+ cur_token = next_token .view (1 , - 1 )
82
+
73
83
return new_tokens , new_probs
74
84
75
85
@@ -248,6 +258,7 @@ def main(
248
258
profile : Optional [Path ] = None ,
249
259
draft_checkpoint_path : Optional [Path ] = None ,
250
260
speculate_k : int = 5 ,
261
+ device = 'cuda' ,
251
262
) -> None :
252
263
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
253
264
"""
@@ -264,7 +275,7 @@ def main(
264
275
# only print on rank 0
265
276
print = lambda * args , ** kwargs : None
266
277
267
- device = 'cuda'
278
+ print ( f"Using device= { device } " )
268
279
precision = torch .bfloat16
269
280
is_speculative = draft_checkpoint_path is not None
270
281
is_chat = "chat" in str (checkpoint_path )
@@ -278,7 +289,7 @@ def main(
278
289
else :
279
290
draft_model = None
280
291
281
- torch . cuda . synchronize ()
292
+ device_sync ( device = device ) # MKG
282
293
print (f"Time to load model: { time .time () - t0 :.02f} seconds" )
283
294
284
295
tokenizer = SentencePieceProcessor (model_file = str (tokenizer_path ))
@@ -288,7 +299,7 @@ def main(
288
299
torch .manual_seed (1234 )
289
300
model_size = sum ([p .numel () * p .dtype .itemsize for p in itertools .chain (model .parameters (), model .buffers ())])
290
301
if compile :
291
- if is_speculative and use_tp :
302
+ if is_speculative and use_tp : # and ("cuda" in device):
292
303
torch ._inductor .config .triton .cudagraph_trees = False # Bug with cudagraph trees in this case
293
304
294
305
if is_speculative :
@@ -310,7 +321,7 @@ def main(
310
321
start = - 1 if compile else 0
311
322
312
323
for i in range (start , num_samples ):
313
- torch . cuda . synchronize ()
324
+ device_sync ( device = device ) # MKG
314
325
if i >= 0 and interactive :
315
326
prompt = input ("What is your prompt? " )
316
327
if is_chat :
@@ -362,7 +373,7 @@ def callback(x):
362
373
prof .export_chrome_trace (f"{ profile } _rank_{ rank } .json" )
363
374
else :
364
375
prof .export_chrome_trace (f"{ profile } .json" )
365
- torch . cuda . synchronize ()
376
+ device_sync ( device = device ) # MKG
366
377
t = time .perf_counter () - t0
367
378
368
379
if not interactive :
@@ -401,9 +412,11 @@ def callback(x):
401
412
parser .add_argument ('--profile' , type = Path , default = None , help = 'Profile path.' )
402
413
parser .add_argument ('--speculate_k' , type = int , default = 5 , help = 'Speculative execution depth.' )
403
414
parser .add_argument ('--draft_checkpoint_path' , type = Path , default = None , help = 'Draft checkpoint path.' )
415
+ parser .add_argument ('--device' , type = str , default = "cuda" , help = 'device to use' )
404
416
405
417
args = parser .parse_args ()
406
418
main (
407
419
args .prompt , args .interactive , args .num_samples , args .max_new_tokens , args .top_k ,
408
- args .temperature , args .checkpoint_path , args .compile , args .compile_prefill , args .profile , args .draft_checkpoint_path , args .speculate_k
420
+ args .temperature , args .checkpoint_path , args .compile , args .compile_prefill , args .profile , args .draft_checkpoint_path ,
421
+ args .speculate_k , args .device
409
422
)
0 commit comments