@@ -189,14 +189,14 @@ def _schedule_prefill(self):
189189 copy_map : Dict [int , int ] = dict ()
190190 running : SeqList = []
191191 token_count = 0
192+ prealloc_size = self .num_spec_tokens or self .num_spec_tokens - 1
192193
193194 def _to_running (seq : SchedulerSequence ):
194195 """To running."""
195196 seq .status = MessageStatus .RUNNING
196197 running .append (seq )
197198 nonlocal token_count
198199 token_count += seq .num_token_ids
199- token_count += self .num_spec_tokens
200200 token_count += len (seq .spec_token_ids )
201201
202202 def __evict_for_seq (seq : SchedulerSequence , waiting ):
@@ -205,7 +205,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting):
205205 hanging = reversed (self .hanging )
206206 waiting = reversed (waiting )
207207 evictable = list (chain (hanging , waiting ))
208- return eviction_helper .evict_for_seq (seq , evictable , prealloc_size = self . num_spec_tokens )
208+ return eviction_helper .evict_for_seq (seq , evictable , prealloc_size = prealloc_size )
209209
210210 def _reorder_waiting ():
211211 """Reorder waiting."""
@@ -218,7 +218,7 @@ def _reorder_waiting():
218218 waiting = _reorder_waiting ()
219219 while len (waiting ) > 0 and len (running ) < max_batches :
220220 seq = waiting .pop (0 )
221- cur_token_count = token_count + seq .num_token_ids + self . num_spec_tokens + len (seq .spec_token_ids )
221+ cur_token_count = token_count + seq .num_token_ids + len (seq .spec_token_ids )
222222 if (len (running ) > 0 and cur_token_count > self .cache_config .max_prefill_token_num ):
223223 break
224224
@@ -228,7 +228,7 @@ def _reorder_waiting():
228228 break
229229
230230 # allocate session memory
231- self .block_manager .allocate (seq , prealloc_size = self . num_spec_tokens )
231+ self .block_manager .allocate (seq , prealloc_size = prealloc_size )
232232 _to_running (seq )
233233
234234 seq .record_event (EventType .SCHEDULED )
0 commit comments