Skip to content

Commit ce8c6be

Browse files
authored
Merge pull request #37 from mikekgfb/main
Support code gen for non-cuda targets with gpt-fast
2 parents 8c8c463 + fcf7fc8 commit ce8c6be

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

generate.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
import torch._dynamo.config
1414
import torch._inductor.config
1515

16+
def device_sync(device):
17+
if "cuda" in device:
18+
torch.cuda.synchronize()
19+
elif "cpu" in device:
20+
pass
21+
else:
22+
print(f"device={device} is not yet suppported")
23+
24+
1625
torch._inductor.config.coordinate_descent_tuning = True
1726
torch._inductor.config.triton.unique_kernel_names = True
1827
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
@@ -65,11 +74,12 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
6574
next_token, next_prob = decode_one_token(
6675
model, cur_token, input_pos, **sampling_kwargs
6776
)
68-
input_pos += 1
69-
new_tokens.append(next_token.clone())
70-
callback(new_tokens[-1])
71-
new_probs.append(next_prob.clone())
72-
cur_token = next_token.view(1, -1)
77+
input_pos += 1
78+
new_tokens.append(next_token.clone())
79+
callback(new_tokens[-1])
80+
new_probs.append(next_prob.clone())
81+
cur_token = next_token.view(1, -1)
82+
7383
return new_tokens, new_probs
7484

7585

@@ -248,6 +258,7 @@ def main(
248258
profile: Optional[Path] = None,
249259
draft_checkpoint_path: Optional[Path] = None,
250260
speculate_k: int = 5,
261+
device='cuda',
251262
) -> None:
252263
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
253264
"""
@@ -264,7 +275,7 @@ def main(
264275
# only print on rank 0
265276
print = lambda *args, **kwargs: None
266277

267-
device = 'cuda'
278+
print(f"Using device={device}")
268279
precision = torch.bfloat16
269280
is_speculative = draft_checkpoint_path is not None
270281
is_chat = "chat" in str(checkpoint_path)
@@ -278,7 +289,7 @@ def main(
278289
else:
279290
draft_model = None
280291

281-
torch.cuda.synchronize()
292+
device_sync(device=device) # MKG
282293
print(f"Time to load model: {time.time() - t0:.02f} seconds")
283294

284295
tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
@@ -288,7 +299,7 @@ def main(
288299
torch.manual_seed(1234)
289300
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
290301
if compile:
291-
if is_speculative and use_tp:
302+
if is_speculative and use_tp: # and ("cuda" in device):
292303
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
293304

294305
if is_speculative:
@@ -310,7 +321,7 @@ def main(
310321
start = -1 if compile else 0
311322

312323
for i in range(start, num_samples):
313-
torch.cuda.synchronize()
324+
device_sync(device=device) # MKG
314325
if i >= 0 and interactive:
315326
prompt = input("What is your prompt? ")
316327
if is_chat:
@@ -362,7 +373,7 @@ def callback(x):
362373
prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
363374
else:
364375
prof.export_chrome_trace(f"{profile}.json")
365-
torch.cuda.synchronize()
376+
device_sync(device=device) # MKG
366377
t = time.perf_counter() - t0
367378

368379
if not interactive:
@@ -401,9 +412,11 @@ def callback(x):
401412
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
402413
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
403414
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
415+
parser.add_argument('--device', type=str, default="cuda", help='device to use')
404416

405417
args = parser.parse_args()
406418
main(
407419
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
408-
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, args.speculate_k
420+
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
421+
args.speculate_k, args.device
409422
)

0 commit comments

Comments
 (0)