
Commit fcf7fc8: Remove use_sdpa option
1 parent: c0125d7


generate.py (7 additions, 23 deletions)
@@ -67,21 +67,10 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)

-def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_sdpa=False, callback=lambda _: _, **sampling_kwargs):
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
-    if not use_sdpa:
-        for i in range(num_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-                next_token, next_prob = decode_one_token(
-                    model, cur_token, input_pos, **sampling_kwargs
-                )
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
-    else:
-        for i in range(num_new_tokens):
+    for i in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
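The path that survives pins scaled_dot_product_attention to the math backend, per the inline comment: with the fused flash and memory-efficient kernels disabled, Inductor can trace through the attention and generate the code itself. A minimal sketch of that context manager in isolation (the tensor shapes here are illustrative, not from this repo; note also that torch.backends.cuda.sdp_kernel has since been deprecated in newer PyTorch releases in favor of torch.nn.attention.sdpa_kernel):

    import torch
    import torch.nn.functional as F

    # Illustrative (batch, heads, seq_len, head_dim) tensors.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

    # Disable the fused flash and memory-efficient SDPA backends; with only
    # the math backend enabled, torch.compile/Inductor can trace through the
    # attention and codegen its own kernel instead of calling an opaque
    # fused op.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)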
@@ -103,13 +92,12 @@ def speculative_decode(
     cur_token: torch.Tensor,
     input_pos: int,
     speculate_k: int,
-    use_sdpa=False,
     **sampling_kwargs
 ) -> torch.Tensor:
     # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, use_sdpa=use_sdpa, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)

     draft_tokens = torch.cat(draft_tokens)
     # parallel inference on target model using draft tokens
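One consequence for any out-of-tree caller: decode_n_tokens forwards its **sampling_kwargs through decode_one_token into sample(), so a stale use_sdpa keyword is no longer absorbed by the signature. Assuming sample() accepts only sampling parameters such as temperature and top_k, a leftover call site would now fail at sampling time rather than being silently ignored:

    # Hypothetical stale call site after this commit: use_sdpa falls into
    # **sampling_kwargs and is forwarded to sample(), which should raise
    # TypeError: sample() got an unexpected keyword argument 'use_sdpa'.
    draft_tokens, draft_probs = decode_n_tokens(
        draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k,
        use_sdpa=False,  # removed parameter; now an error
        **sampling_kwargs,
    )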
@@ -157,7 +145,6 @@ def generate(
     interactive: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
-    use_sdpa=False,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -201,7 +188,7 @@ def generate(
             cur_token = next_token.view(())

             next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, use_sdpa=use_sdpa, **sampling_kwargs
+                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
             )

             accept_counts[len(next_tokens) - 1] += 1
@@ -212,7 +199,7 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, use_sdpa=use_sdpa, callback=callback, **sampling_kwargs)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
         seq[T + 1:] = torch.cat(generated_tokens)

     generate_stats = {
@@ -271,7 +258,6 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    use_sdpa=False,
     device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -374,7 +360,6 @@ def callback(x):
                 max_new_tokens,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
-                use_sdpa=use_sdpa,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
@@ -428,12 +413,11 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--use_sdpa', action='store_true', help='Whether to use SDPA')
     parser.add_argument('--device', type=str, default="cuda", help='device to use')

     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.use_sdpa, args.device
+        args.speculate_k, args.device
     )
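With the flag dropped from the parser, argparse now rejects --use_sdpa outright (exiting with "unrecognized arguments"). A hypothetical invocation after this commit, with illustrative checkpoint paths:

    # Speculative decoding run; the former --use_sdpa flag is no longer accepted.
    python generate.py \
        --checkpoint_path checkpoints/model.pth \
        --draft_checkpoint_path checkpoints/draft.pth \
        --speculate_k 5 \
        --device cuda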
