 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+from vllm.compilation.ubatch_utils import UbatchSlice, UBatchSlices
 from vllm.compilation.ubatch_wrapper import UBatchWrapper
-from vllm.compilation.ubatch_utils import (UbatchSlice, UBatchSlices)
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    UbatchSlice, create_fast_prefill_custom_backend,
+    create_fast_prefill_custom_backend,
     reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -605,8 +605,8 @@ def _ubatch_split(
         num_pad_tokens = 0
         num_tokens_after_padding = None
         (should_ubatch, num_pad_tokens,
-         num_tokens_after_padding) = self.get_dp_padding_ubatch(total_num_scheduled_tokens,
-                                                                should_attempt_ubatching)
+         num_tokens_after_padding) = self.get_dp_padding_ubatch(
+             total_num_scheduled_tokens, should_attempt_ubatching)
         if not should_ubatch:
             return (None, 0, None)
 
@@ -1570,16 +1570,16 @@ def get_dp_padding_ubatch(
         should_ubatch = False
 
         # Note that we compute the number of padded tokens per ubatch
-        (should_ubatch,
-         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(should_ubatch,
-            num_tokens_unpadded // 2, num_tokens_per_ubatch)
+        (should_ubatch,
+         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
+             should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch)
         if not should_ubatch:
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
 
         assert num_tokens_across_dp is not None
 
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                 dp_size,
                                                 device="cpu",
@@ -1594,22 +1594,23 @@ def get_dp_padding_ubatch(
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
     def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
-                               num_total_tokens: int):
+                             num_total_tokens: int):
         padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                            num_total_tokens)
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
                                        padded_second_ubatch_slice)
 
-    def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
-                                      padded_num_tokens_per_ubatch: int,
-                                      ) -> tuple[bool, Optional[torch.Tensor]]:
+    def should_ubatch_with_num_tokens(
+        self,
+        should_ubatch: bool,
+        orig_num_tokens_per_ubatch: int,
+        padded_num_tokens_per_ubatch: int,
+    ) -> tuple[bool, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch,
-                                                  orig_num_tokens_per_ubatch,
-                                                  padded_num_tokens_per_ubatch,
-                                                  dp_size,
-                                                  dp_rank)
+        return DPMetadata.should_ubatch_across_dp(
+            should_ubatch, orig_num_tokens_per_ubatch,
+            padded_num_tokens_per_ubatch, dp_size, dp_rank)
 
     def _pool(
         self,
@@ -2426,23 +2427,26 @@ def _dummy_run(
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         ubatch_enabled = self.parallel_config.enable_microbatching
+        num_tokens_across_dp = None
+        num_pad = 0
         should_ubatch = False
         if ubatch_enabled:
             should_ubatch = num_tokens >= \
                 self.parallel_config.microbatching_token_threshold and \
                 allow_microbatching
-            should_ubatch, _ = self.should_ubatch_with_num_tokens(
-                should_ubatch,
-                num_tokens // 2,
-                num_tokens // 2,
-            )
+
+            (should_ubatch, num_pad,
+             num_tokens_across_dp) = self.get_dp_padding_ubatch(
+                 num_tokens, should_ubatch)
+
+            # Currently the dummy run should only be ubatching during
+            # cuda graph capture, meaning all DP ranks should already
+            # have the same batch size
+            assert num_pad == 0
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
-        # Padding for DP
-        num_tokens_across_dp = None
-        num_pad = 0
         if not should_ubatch:
             num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
             num_tokens += num_pad
@@ -2497,12 +2501,12 @@ def _dummy_run(
             # We only support decode-only cudagraphs
             assert num_reqs == num_tokens
             assert num_tokens % 2 == 0
-            num_tokens_per_ubatch = num_tokens // 2
-            dp_size = self.vllm_config.parallel_config.data_parallel_size
-            num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
-                                                dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
+            # num_tokens_per_ubatch = num_tokens // 2
+            # dp_size = self.vllm_config.parallel_config.data_parallel_size
+            # num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
+            #                                     dp_size,
+            #                                     device="cpu",
+            #                                     dtype=torch.int32)
             ubatch_slices = [
                 UbatchSlice(slice(0, num_reqs // 2), slice(0,
                                                            num_tokens // 2)),
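
The hunks above all revolve around two ideas: splitting a token batch into two microbatches ("ubatches") described by `UbatchSlice` objects, and padding the second ubatch out to the DP-padded token count so every data-parallel rank runs the same shape. The sketch below only illustrates that shape and is not vllm's implementation: `UbatchSlice` here is a stand-in dataclass, and `make_ubatch_slices`, `pad_tokens_across_dp`, and `pad_out_second_ubatch` are hypothetical helpers named for this example (the real types live in `vllm.compilation.ubatch_utils`).

```python
from dataclasses import dataclass


@dataclass
class UbatchSlice:
    # Stand-in for vllm.compilation.ubatch_utils.UbatchSlice: one slice over
    # the requests in the batch and one over the flattened token positions.
    request_slice: slice
    token_slice: slice


def make_ubatch_slices(num_reqs: int, num_tokens: int) -> list[UbatchSlice]:
    # Split requests and tokens roughly in half, mirroring the decode-only
    # dummy-run path above (where num_reqs == num_tokens).
    return [
        UbatchSlice(slice(0, num_reqs // 2), slice(0, num_tokens // 2)),
        UbatchSlice(slice(num_reqs // 2, num_reqs),
                    slice(num_tokens // 2, num_tokens)),
    ]


def pad_tokens_across_dp(tokens_per_ubatch_by_rank: list[int]) -> int:
    # Every DP rank must run the same ubatch size, so each rank pads its
    # ubatches up to the maximum across ranks (cf. get_dp_padding_ubatch).
    return max(tokens_per_ubatch_by_rank)


def pad_out_second_ubatch(ubatch_slices: list[UbatchSlice],
                          num_total_tokens: int) -> None:
    # Stretch the second ubatch's token range to cover the original tokens
    # plus any DP padding, mirroring pad_out_ubatch_slice in the diff.
    padded = slice(ubatch_slices[1].token_slice.start, num_total_tokens)
    ubatch_slices[1] = UbatchSlice(padded, padded)


if __name__ == "__main__":
    slices = make_ubatch_slices(num_reqs=8, num_tokens=8)
    padded_per_ubatch = pad_tokens_across_dp([4, 5, 3])  # this rank pads 4 -> 5
    pad_out_second_ubatch(slices, num_total_tokens=2 * padded_per_ubatch)
    print(slices[0].token_slice, slices[1].token_slice)
    # -> slice(0, 4, None) slice(4, 10, None)
```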