@@ -1539,7 +1539,7 @@ def get_dp_padding_ubatch(
         if ubatch_slices is None:
             (should_ubatch,
              num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
-                 False, 0)
+                 False, 0, 0)
             assert should_ubatch is False
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
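Even with `ubatch_slices is None` the check still runs rather than returning immediately, presumably because the underlying `should_ubatch_across_dp` involves a collective that every data-parallel rank must enter. The added third `0` matches the widened signature introduced in the hunk at line 1634 below, which passes the original and padded per-ubatch counts separately; both are zero on this path. A minimal sketch of the call shape:

    # Zero real tokens and zero padded tokens: this rank votes "no ubatch"
    # but still participates in the DP-wide decision (assumed collective).
    self.should_ubatch_with_num_tokens(False, 0, 0)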
@@ -1581,9 +1581,9 @@ def get_dp_padding_ubatch(
             should_ubatch = False
 
         # Note that we compute the number of padded tokens per ubatch
-        (should_ubatch,
-         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
-             should_ubatch, num_tokens_per_ubatch)
+        (should_ubatch,
+         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(should_ubatch,
+            num_tokens_unpadded // 2, num_tokens_per_ubatch)
         if not should_ubatch:
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
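The call now forwards two distinct counts: `num_tokens_unpadded // 2` is the real per-microbatch work on this rank, while `num_tokens_per_ubatch` already includes padding. A sketch with made-up values showing how the two can diverge (the round-up rule and the numbers are assumptions for illustration, not taken from this commit):

    num_tokens_unpadded = 100                              # real tokens on this rank
    orig_num_tokens_per_ubatch = num_tokens_unpadded // 2  # 50 real tokens per ubatch
    num_tokens_per_ubatch = 64                             # 50 padded up, e.g. to a
                                                           # CUDA graph capture size
    # Sharing both presumably lets the DP-wide decision weigh real work and
    # the padded size separately, instead of seeing only the padded value.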
@@ -1607,7 +1607,7 @@ def get_dp_padding_ubatch(
     def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
                                    num_pad_tokens: int):
         original_num_tokens = ubatch_slices[1].token_slice.stop
-        assert num_pad_tokens < original_num_tokens
+        assert num_pad_tokens < original_num_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {original_num_tokens}"
         total_num_tokens_per_ubatch = (original_num_tokens +
                                        num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
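With the message attached, a failed check is now self-describing. Hypothetical values for illustration:

    # Before: bare AssertionError with no context.
    # After:  AssertionError: num_pad_tokens 10 original_num_tokens 8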
@@ -1631,16 +1631,16 @@ def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
                                        padded_second_ubatch_slice)
 
-    def should_ubatch_with_num_tokens(
-        self,
-        should_ubatch: bool,
-        num_tokens_per_ubatch: int,
-    ) -> tuple[bool, Optional[torch.Tensor]]:
+    def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+                                      padded_num_tokens_per_ubatch: int,
+                                      ) -> tuple[bool, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch,
-                                                  num_tokens_per_ubatch,
-                                                  dp_size, dp_rank)
+        return DPMetadata.should_ubatch_across_dp(should_ubatch,
+                                                  orig_num_tokens_per_ubatch,
+                                                  padded_num_tokens_per_ubatch,
+                                                  dp_size,
+                                                  dp_rank)
 
     def _pool(
         self,
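For callers, the widened method reads as below; this is a sketch of the call shape only, with the parameter roles taken from the signature above and the surrounding variables assumed:

    should_ubatch, num_tokens_across_dp = self.should_ubatch_with_num_tokens(
        should_ubatch,             # this rank's local preference
        num_tokens_unpadded // 2,  # orig_num_tokens_per_ubatch
        num_tokens_per_ubatch,     # padded_num_tokens_per_ubatch
    )
    # DPMetadata.should_ubatch_across_dp then reconciles the decision across
    # all data-parallel ranks using both counts.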
@@ -2472,6 +2472,7 @@ def _dummy_run(
             should_ubatch, _ = self.should_ubatch_with_num_tokens(
                 should_ubatch,
                 num_tokens // 2,
+                num_tokens // 2,
             )
             assert cudagraph_runtime_mode in {
                 CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
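In `_dummy_run` the batch is synthetic and already sized to its final token count, so the same value serves as both the original and the padded per-ubatch count; passing `num_tokens // 2` twice keeps this path consistent with the new three-argument signature. For example (value assumed):

    num_tokens = 512               # assumed dummy batch size
    per_ubatch = num_tokens // 2   # 256: used for both orig and padded counts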