@@ -62,7 +62,7 @@ def speculative_decoding(
     target_model: GPT,
     token: torch.Tensor,
     input_pos: torch.Tensor,
-    input_pos_maxp1: torch.Tensor,
+    input_pos_maxp1: int,
     speculative_k: int,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
@@ -100,7 +100,7 @@ def speculative_decoding(
     # Step 1: Generate candidate tokens using draft model
     # The draft model autoregressively generates k tokens, keeping track of probabilities
     draft_input_pos = input_pos.clone()
-    draft_input_pos_maxp1 = input_pos_maxp1.clone()
+    draft_input_pos_maxp1 = input_pos_maxp1
     draft_tokens, draft_probs = [], []
     draft_token = token
     for idx in range(speculative_k):
@@ -109,7 +109,7 @@ def speculative_decoding(
         )
         draft_token, draft_prob = sample(logits, **sample_kwargs)
         draft_input_pos.add_(1)
-        draft_input_pos_maxp1.add_(1)
+        draft_input_pos_maxp1 += 1
         draft_tokens.append(draft_token)
         draft_probs.append(draft_prob)
     draft_tokens = torch.cat(draft_tokens)
@@ -118,7 +118,7 @@ def speculative_decoding(
     # Feed both original token and draft tokens to get target probabilities
     candidate_tokens = torch.cat((token, draft_tokens))
     candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device)
-    candidate_input_pos_maxp1 = input_pos_maxp1.add(speculative_k)
+    candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k
     target_logits = target_model(
         idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1
     )
@@ -228,7 +228,10 @@ def generate(
 
     # Step 1: Prefill draft and target models with the prompt.
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
+    input_pos_maxp1 = (
+        prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None
+    )
     next_token(
         draft_model,
         input_pos,
@@ -249,7 +252,7 @@ def generate(
     )
     # Update position trackers after prompt
     input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
-    input_pos_maxp1.add_(1)
+    input_pos_maxp1 += 1
 
     # Step 2: Main generation loop.
     tokens = []
@@ -289,7 +292,7 @@ def generate(
 
         # Update positions for next iteration
         input_pos.add_(accepted_tokens_len)
-        input_pos_maxp1.add_(accepted_tokens_len)
+        input_pos_maxp1 += accepted_tokens_len
         token = new_tokens[-1].unsqueeze(0)
 
     # Finalize generated sequence
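
For context, the diff above tracks input_pos_maxp1 as a plain Python int (or None when a Thunder-compiled module is in play) instead of a device tensor, so the per-step bookkeeping becomes ordinary integer arithmetic and the value can serve directly as a Python-level slice bound on the KV cache. The following is a minimal sketch of that pattern only; toy_attend, the cache layout, and all shapes are assumptions for illustration and are not the litgpt API.

import torch

# Illustrative KV-cache pattern: input_pos_maxp1 is a plain int used only as a
# Python-level slice bound, so advancing it needs no tensor ops or device sync.
k_cache = torch.zeros(1, 8, 128, 64)  # assumed (batch, heads, max_seq_len, head_dim)


def toy_attend(q, k_cache, input_pos_maxp1=None):
    # With no bound (e.g. when an int bound is not supported), attend over the
    # full cache; otherwise slice the cache up to the filled positions only.
    k = k_cache if input_pos_maxp1 is None else k_cache[:, :, :input_pos_maxp1]
    scores = torch.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
    return scores @ k


prompt_size = 5
input_pos_maxp1 = prompt_size  # plain int: no .clone() / .add_() needed
for _ in range(3):  # each decode step advances the bound by the accepted length
    q = torch.randn(1, 8, 1, 64)
    _ = toy_attend(q, k_cache, input_pos_maxp1)
    input_pos_maxp1 += 1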