forked from ModelEngine-Group/unified-cache-management
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvllm_patch.py
More file actions
1992 lines (1755 loc) · 87.1 KB
/
vllm_patch.py
File metadata and controls
1992 lines (1755 loc) · 87.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# MIT License
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from __future__ import annotations
import os
from ucm.logger import init_logger
logger = init_logger(__name__)
ENABLE_SPARSE = os.getenv("ENABLE_SPARSE")
def _enable_sparse() -> bool:
return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true"
def _apply_sparse_adapt() -> None:
"""Apply sparse adapt patches."""
try:
if _enable_sparse():
_patch_block_table()
_patch_kv_cache_manager()
_patch_shared_storage_connector()
_patch_attention_layer()
_patch_mla_common()
_patch_gpu_model_runner()
_patch_gpu_worker()
_patch_scheduler_output()
_patch_scheduler()
logger.info("UCM sparse adapt patches applied successfully")
except Exception as e:
logger.error(f"Could not apply sparse adapt patches: {e}")
raise e
# ==================== vllm/v1/core/sched/output.py ====================
def _patch_scheduler_output() -> None:
"""Patch scheduler output to add UCM sparse support."""
try:
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
)
from vllm.v1.core.sched import output
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData
@dataclass
class SchedulerOutput:
# list of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: list[NewRequestData]
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: dict[str, list[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
# modified slots by sparse algorithm
req_sparsed_slots: dict[str, int] = None
# Set module and qualname to make the class pickleable
# This ensures pickle can find the class when serializing
SchedulerOutput.__module__ = output.__name__
SchedulerOutput.__qualname__ = "SchedulerOutput"
output.SchedulerOutput = SchedulerOutput
except ImportError:
logger.warning("Could not patch scheduler output - module not found")
# ==================== vllm/attention/layer.py ====================
def _patch_attention_layer() -> None:
"""Patch attention layer & unified_attention_with_output C++ op."""
try:
from typing import Optional
import torch
from vllm.attention.layer import (
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector,
)
from vllm.forward_context import ForwardContext, get_forward_context
from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
def maybe_execute_sparse_attention_begin(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
forward_context: ForwardContext,
phase: Optional[str] = None,
):
if not has_ucm_sparse():
return
ucm_sparse = get_ucm_sparse()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
ucm_sparse.attention_begin(
query, key, value, layer_name, forward_context, phase
)
def maybe_execute_sparse_attention_finished(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_output: torch.Tensor,
layer_name: str,
forward_context: ForwardContext,
phase: Optional[str] = None,
):
if not has_ucm_sparse():
return
ucm_sparse = get_ucm_sparse()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
ucm_sparse.attention_finished(
query, key, value, attn_output, layer_name, forward_context, phase
)
vllm_ops = torch.ops.vllm
orig_unified_attention_with_output = vllm_ops.unified_attention_with_output
orig_unified_attention = vllm_ops.unified_attention
def _wrap_op_overload(orig, impl):
class _Wrapper:
def __init__(self, orig):
self._orig = orig
def __call__(self, *args, **kwargs):
return impl(*args, **kwargs)
def __getattr__(self, name):
return getattr(self._orig, name)
return _Wrapper(orig)
def unified_attention_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
maybe_execute_sparse_attention_begin(
query, key, value, layer_name, forward_context
)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
maybe_execute_sparse_attention_finished(
query, key, value, output, layer_name, forward_context
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_attention_with_output_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
if not self.use_mla:
maybe_execute_sparse_attention_begin(
query, key, value, layer_name, forward_context
)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
)
if not self.use_mla:
maybe_execute_sparse_attention_finished(
query, key, value, output, layer_name, forward_context
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
vllm_ops.unified_attention_with_output = _wrap_op_overload(
orig_unified_attention_with_output, unified_attention_with_output_impl
)
vllm_ops.unified_attention = _wrap_op_overload(
orig_unified_attention, unified_attention_impl
)
from vllm.attention import layer
layer.maybe_execute_sparse_attention_begin = (
maybe_execute_sparse_attention_begin
)
layer.maybe_execute_sparse_attention_finished = (
maybe_execute_sparse_attention_finished
)
layer.unified_attention = unified_attention_impl
layer.unified_attention_with_output = unified_attention_with_output_impl
except ImportError:
logger.warning(
"Could not patch unified attention with output - module not found"
)
# ==================== v1/shared_storage_connector.py ====================
def _patch_shared_storage_connector() -> None:
"""Patch kv connector utils to add UCM sparse support."""
try:
from dataclasses import dataclass, field
from vllm.distributed.kv_transfer.kv_connector.v1 import (
shared_storage_connector,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
ReqMeta,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)
)
shared_storage_connector.SharedStorageConnectorMetadata = (
SharedStorageConnectorMetadata
)
except ImportError:
logger.warning("Could not patch shared storage connector - module not found")
# ==================== vllm/v1/attention/backends/mla/common.py ====================
def _patch_mla_common() -> None:
"""Patch mla common to add UCM sparse support."""
try:
from typing import Optional, TypeVar
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.layer import (
maybe_execute_sparse_attention_begin,
maybe_execute_sparse_attention_finished,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.v1.attention.backends.mla.common import (
MLACommonImpl,
MLACommonMetadata,
)
M = TypeVar("M", bound=MLACommonMetadata)
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl"
)
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if has_prefill:
maybe_execute_sparse_attention_begin(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
layer.layer_name,
forward_context,
"prefill",
)
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata
)
maybe_execute_sparse_attention_finished(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
output[num_decode_tokens:],
layer.layer_name,
forward_context,
"prefill",
)
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
maybe_execute_sparse_attention_begin(
torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
decode_ql_nope,
decode_q_pe,
layer.layer_name,
forward_context,
"decode",
)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata
)
maybe_execute_sparse_attention_finished(
torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
decode_ql_nope,
decode_q_pe,
output[:num_decode_tokens],
layer.layer_name,
forward_context,
"decode",
)
return output_padded
MLACommonImpl.forward = forward
except ImportError:
logger.warning("Could not patch mla common - module not found")
# ==================== v1/core/kv_cache_manager.py ====================
def _patch_kv_cache_manager() -> None:
"""Patch kv cache manager to add UCM sparse support."""
try:
from typing import Optional, Union
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.request import Request
from ucm.sparse.base import INVALID_SLOT
from ucm.sparse.state import get_ucm_sparse
original_allocate_slots = KVCacheManager.allocate_slots
def patched_allocate_slots(
self,
request: Request,
num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_draft_tokens: int = 0,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
num_slots_sparsed: Union[None, int] = None,
) -> Optional[KVCacheBlocks]:
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
# Only route to UCM sparse path when caller explicitly provided
# a valid sparsified slot count.
if (num_slots_sparsed is not None) and (num_slots_sparsed != INVALID_SLOT):
return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed)
return original_allocate_slots(
self,
request,
num_new_tokens,
num_new_computed_tokens,
new_computed_blocks,
num_draft_tokens,
num_lookahead_tokens,
delay_cache_blocks,
)
KVCacheManager.allocate_slots = patched_allocate_slots
except ImportError:
logger.warning("Could not patch kv cache manager - module not found")
# ==================== vllm/v1/core/sched/scheduler.py ====================
def _patch_scheduler() -> None:
"""Patch Scheduler to add num_output_tokens field."""
try:
import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
)
from vllm.v1.core.sched.output import (
CachedRequestData,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import (
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (
EngineCoreEventType,
EngineCoreOutput,
EngineCoreOutputs,
)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from ucm.sparse.base import INVALID_SLOT, UcmSparseRole
from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse
def init_ucm_sparse(self):
self.ucm_sparse = None
if self.vllm_config.kv_transfer_config is not None:
if (
"ucm_sparse_config"
in self.vllm_config.kv_transfer_config.kv_connector_extra_config
):
ensure_ucm_sparse_initialized(
self.vllm_config, role=UcmSparseRole.SCHEDULER
)
self.ucm_sparse = get_ucm_sparse()
logger.info(
"UCM Sparse initialized successfully: {}".format(
self.ucm_sparse
)
)
def patched_schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to the running request index.
# This will helps us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
req_index = 0
req_sparsed_slots: dict[str, int] = {}
if not hasattr(self, "ucm_sparse"):
init_ucm_sparse(self)
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_slots_sparsed = INVALID_SLOT
if self.ucm_sparse:
num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(
request
)
req_sparsed_slots.update({request.request_id: num_slots_sparsed})
num_new_tokens = (
request.num_tokens_with_spec - request.num_computed_tokens
)
if (
0
< self.scheduler_config.long_prefill_token_threshold
< num_new_tokens
):
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget) = (
self._try_schedule_encoder_inputs(
request,
request.num_computed_tokens,
num_new_tokens,
encoder_budget,
)
)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue
num_draft_tokens = max(
num_new_tokens + request.num_computed_tokens - request.num_tokens, 0
)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
num_slots_sparsed=num_slots_sparsed,
)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = new_blocks.get_block_ids()
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids
)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule
)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id
for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
num_slots_sparsed = INVALID_SLOT
if self.ucm_sparse:
num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(
request
)
req_sparsed_slots.update({request.request_id: num_slots_sparsed})
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (
self.lora_config
and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request)
)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens
)
)
# Total computed tokens (local + external).
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list()
)
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (
0
< self.scheduler_config.long_prefill_token_threshold
< num_new_tokens
):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold
)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if (
not self.scheduler_config.chunked_prefill_enabled
and num_new_tokens > token_budget
):
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_budget,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
num_new_tokens,
encoder_budget,
)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_slots_sparsed=num_slots_sparsed,
)
if new_blocks is None:
# The request cannot be scheduled.
break
# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[request.request_id] = req_index
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(
EngineCoreEventType.SCHEDULED, scheduled_timestamp
)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id)
)
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule
)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs
) <= len(self.running)
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)
)
)
grammar_bitmask = self.structured_output_manager.grammar_bitmask(