
Commit 029db46

Add files via upload
1 parent db7b273 commit 029db46

File tree

1 file changed: +49 -19 lines changed


generate.py

Lines changed: 49 additions & 19 deletions
@@ -13,9 +13,18 @@
 import torch._dynamo.config
 import torch._inductor.config
 
+def device_sync(device):
+    if "cuda" in device:
+        torch.cuda.synchronize()
+    elif "cpu" in device:
+        pass
+    else:
+        print(f"device={device} is not yet supported")
+
+
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
 
 # support running without installing as a package
@@ -58,18 +67,30 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)
 
-def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_sdpa=False, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+    if not use_sdpa:
+        for i in range(num_new_tokens):
+            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+                next_token, next_prob = decode_one_token(
+                    model, cur_token, input_pos, **sampling_kwargs
+                )
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            callback(new_tokens[-1])
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, -1)
+    else:
+        for i in range(num_new_tokens):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
-        input_pos += 1
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
-        new_probs.append(next_prob.clone())
-        cur_token = next_token.view(1, -1)
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            callback(new_tokens[-1])
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, -1)
+
     return new_tokens, new_probs
 
 
@@ -82,12 +103,13 @@ def speculative_decode(
     cur_token: torch.Tensor,
     input_pos: int,
     speculate_k: int,
+    use_sdpa=False,
     **sampling_kwargs
 ) -> torch.Tensor:
     # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, use_sdpa=use_sdpa, **sampling_kwargs)
 
     draft_tokens = torch.cat(draft_tokens)
     # parallel inference on target model using draft tokens
@@ -135,6 +157,7 @@ def generate(
     interactive: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
+    use_sdpa=False,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -178,7 +201,7 @@ def generate(
             cur_token = next_token.view(())
 
             next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
+                model, draft_model, cur_token, input_pos, speculate_k, use_sdpa=use_sdpa, **sampling_kwargs
             )
 
             accept_counts[len(next_tokens) - 1] += 1
@@ -189,7 +212,7 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, use_sdpa=use_sdpa, callback=callback, **sampling_kwargs)
         seq[T + 1:] = torch.cat(generated_tokens)
 
     generate_stats = {
@@ -248,6 +271,8 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
+    use_sdpa=False,
+    device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
@@ -265,7 +290,7 @@ def main(
             # only print on rank 0
             print = lambda *args, **kwargs: None
 
-    device = 'cuda'
+    print(f"Using device={device}")
     precision = torch.bfloat16
     is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
@@ -279,7 +304,7 @@ def main(
     else:
         draft_model = None
 
-    torch.cuda.synchronize()
+    device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
@@ -289,8 +314,9 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp:
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
+        # MKG
+        # if is_speculative and use_tp:
+        #     torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
 
         if is_speculative:
             global model_forward, logits_to_prob
@@ -311,7 +337,7 @@ def main(
     start = -1 if compile else 0
 
     for i in range(start, num_samples):
-        torch.cuda.synchronize()
+        device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
             if is_chat:
@@ -349,6 +375,7 @@ def callback(x):
                 max_new_tokens,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
+                use_sdpa=use_sdpa,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
@@ -363,7 +390,7 @@ def callback(x):
                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
             else:
                 prof.export_chrome_trace(f"{profile}.json")
-        torch.cuda.synchronize()
+        device_sync(device=device) # MKG
         t = time.perf_counter() - t0
 
         if not interactive:
@@ -402,9 +429,12 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
+    parser.add_argument('--use_sdpa', action='store_true', help='Whether to use SDPA')
+    parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, args.speculate_k
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
+        args.speculate_k, args.use_sdpa, args.device
     )
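Note on the device_sync helper added at the top of the file: the previous code called torch.cuda.synchronize() unconditionally before reading the clock, which fails when no CUDA device is present. CUDA kernels launch asynchronously, so timings must wait for the stream to drain, while on CPU there is nothing to wait for. Below is a minimal, self-contained sketch of that timing pattern; it is not part of the commit, and the matmul and tensor sizes are illustrative only.

import time
import torch

def device_sync(device):
    # Same dispatch logic as the helper added in this commit.
    if "cuda" in device:
        torch.cuda.synchronize()
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet supported")

# Illustrative only: time one matmul on the chosen device.
device = "cpu"  # or "cuda" if available
x = torch.randn(1024, 1024, device=device)

device_sync(device)   # make sure earlier work has finished
t0 = time.perf_counter()
y = x @ x
device_sync(device)   # wait for the matmul before reading the clock
print(f"matmul took {time.perf_counter() - t0:.4f} s on {device}")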

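The use_sdpa flag only controls whether decode_n_tokens keeps forcing the math SDPA backend through torch.backends.cuda.sdp_kernel (the in-file comment says this is better for Inductor to codegen attention) or leaves the backend choice to PyTorch's dispatcher. A rough illustration of the two paths, assuming a CUDA device and arbitrary tensor shapes; this is not code from the commit.

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) chosen arbitrarily for the example.
q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.bfloat16)

# use_sdpa=False path: restrict SDPA to the math backend, as the existing
# context manager in decode_n_tokens does.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# use_sdpa=True path: no context manager, so PyTorch may pick the flash or
# memory-efficient kernel instead.
out_auto = F.scaled_dot_product_attention(q, k, v, is_causal=True)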