@@ -67,21 +67,10 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)
 
-def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_sdpa=False, callback=lambda _: _, **sampling_kwargs):
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
-    if not use_sdpa:
-        for i in range(num_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-                next_token, next_prob = decode_one_token(
-                    model, cur_token, input_pos, **sampling_kwargs
-                )
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
-    else:
-        for i in range(num_new_tokens):
+    for i in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
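With the `use_sdpa` flag gone, every decode step now runs unconditionally under the math-only SDPA backend, which the inline comment notes is better for Inductor to codegen than dispatching to the flash/mem-efficient kernels. A minimal sketch of what that context manager does, assuming a CUDA device (toy shapes, not from this PR):

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim). Inside this context manager,
# scaled_dot_product_attention is restricted to the math backend, so
# torch.compile/Inductor can fuse the attention math rather than call into an
# opaque fused kernel.
q = k = v = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.bfloat16)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)
```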
@@ -103,13 +92,12 @@ def speculative_decode(
     cur_token: torch.Tensor,
     input_pos: int,
     speculate_k: int,
-    use_sdpa=False,
     **sampling_kwargs
 ) -> torch.Tensor:
     # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, use_sdpa=use_sdpa, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
 
     draft_tokens = torch.cat(draft_tokens)
     # parallel inference on target model using draft tokens
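For context on the two comments above: `speculative_decode` drafts `speculate_k` tokens one at a time with the small model, then scores all of them in a single forward pass of the target model. A hedged sketch of the standard accept/reject rule it relies on (names are illustrative; the actual logic lives further down in `speculative_decode`):

```python
import torch

def first_rejection(target_probs: torch.Tensor,  # [k, vocab] from target model
                    draft_probs: torch.Tensor,   # [k, vocab] from draft model
                    draft_tokens: torch.Tensor): # [k] drafted token ids
    # Accept draft token i with probability min(1, p_target / p_draft);
    # return how many leading draft tokens survive.
    idx = torch.arange(draft_tokens.size(0), device=draft_tokens.device)
    p = target_probs[idx, draft_tokens]
    q = draft_probs[idx, draft_tokens]
    accept = torch.rand_like(p) < torch.clamp(p / q, max=1.0)
    rejected = (~accept).nonzero()
    return draft_tokens.size(0) if rejected.numel() == 0 else int(rejected[0])
```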
@@ -157,7 +145,6 @@ def generate(
     interactive: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
-    use_sdpa=False,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -201,7 +188,7 @@ def generate(
             cur_token = next_token.view(())
 
             next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, use_sdpa=use_sdpa, **sampling_kwargs
+                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
             )
 
             accept_counts[len(next_tokens) - 1] += 1
@@ -212,7 +199,7 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, use_sdpa=use_sdpa, callback=callback, **sampling_kwargs)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
         seq[T + 1:] = torch.cat(generated_tokens)
 
     generate_stats = {
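These two hunks cover both callers inside `generate`: the speculative path and the plain autoregressive fallback, which now share the simplified `decode_n_tokens` signature. A sketch of a direct call with that signature (model and token values are placeholders):

```python
# Hypothetical call mirroring the fallback path above:
new_tokens, new_probs = decode_n_tokens(
    model,
    next_token.view(1, -1),      # current token reshaped to a [1, 1] batch
    input_pos,                   # position tensor, advanced in place each step
    max_new_tokens - 1,          # tokens remaining after the prefill step
    callback=lambda tok: None,   # per-token hook, e.g. for streaming output
    temperature=0.8, top_k=200,  # forwarded to sample() via **sampling_kwargs
)
```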
@@ -271,7 +258,6 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    use_sdpa=False,
     device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -374,7 +360,6 @@ def callback(x):
         max_new_tokens,
         draft_model=draft_model,
         speculate_k=speculate_k,
-        use_sdpa=use_sdpa,
         interactive=interactive,
         callback=callback,
         temperature=temperature,
@@ -428,12 +413,11 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--use_sdpa', action='store_true', help='Whether to use SDPA')
     parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.use_sdpa, args.device
+        args.speculate_k, args.device
     )
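After this change, `main` takes thirteen positional arguments, with `use_sdpa` dropped between `speculate_k` and `device`. A hypothetical direct call mirroring the argparse wiring above (every value here is a placeholder, not a project default I can vouch for):

```python
from pathlib import Path

main(
    "Hello, my name is",            # prompt
    False,                          # interactive
    5,                              # num_samples
    200,                            # max_new_tokens
    200,                            # top_k
    0.8,                            # temperature
    Path("checkpoints/model.pth"),  # checkpoint_path (placeholder path)
    True,                           # compile
    False,                          # compile_prefill
    None,                           # profile
    Path("checkpoints/draft.pth"),  # draft_checkpoint_path (placeholder path)
    5,                              # speculate_k
    "cuda",                         # device
)
```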