@@ -602,32 +602,31 @@ def _ubatch_split(
            self.parallel_config.microbatching_token_threshold \
            and max_num_scheduled_tokens == 1

-        # For pure decode we can just create ubatches by cutting the request
-        # in half
-        ubatch_slices = None
-        if should_attempt_ubatching:
-            b0_reqs_end = num_reqs // 2
-            b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and \
-                b0_tokens_end < total_num_scheduled_tokens
-            ubatch_slices = [
-                UbatchSlice(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-                UbatchSlice(slice(b0_reqs_end, num_reqs),
-                            slice(b0_tokens_end, total_num_scheduled_tokens)),
-            ]
-
        # Don't microbatch unless every other DP worker is also microbatching
        num_pad_tokens = 0
        num_tokens_after_padding = None
        (should_ubatch, num_pad_tokens,
-         num_tokens_after_padding) = self.get_dp_padding_ubatch(ubatch_slices)
+         num_tokens_after_padding) = self.get_dp_padding_ubatch(
+             total_num_scheduled_tokens, should_attempt_ubatching)
        if not should_ubatch:
            return (None, 0, None)
-        assert ubatch_slices

-        # Compute ubatch padding. This currently only accounts for DP padding
-        if num_pad_tokens > 0:
-            self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
+        # This doesn't actually pad the ubatch slices. It just initializes the
+        # split point to the correct value so that padding can be applied
+        # to the second ubatch in pad_out_ubatch_slice after attention
+        # metadata creation
+        assert num_pad_tokens < total_num_scheduled_tokens, \
+            f"num_pad_tokens {num_pad_tokens} " \
+            f"original_num_tokens {total_num_scheduled_tokens}"
+        total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
+                                       num_pad_tokens) // 2
+        padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
+        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
+                                           total_num_scheduled_tokens)
+
+        # Note there's an assumption here that there's 1 token per request
+        ubatch_slices = [
+            UbatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
+            UbatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
+        ]

        return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
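For readers skimming the hunk, here is a minimal runnable sketch of the split logic added above. It is illustrative only: UbatchSlice is reduced to a plain two-field stand-in for vLLM's class, and split_decode_batch is a hypothetical name. The point is that the split already accounts for the padding, and that in pure decode one slice can index both requests and tokens.

from dataclasses import dataclass

@dataclass
class UbatchSlice:
    # illustrative stand-in for vLLM's UbatchSlice(request_slice, token_slice)
    request_slice: slice
    token_slice: slice

def split_decode_batch(total_num_scheduled_tokens: int,
                       num_pad_tokens: int) -> list[UbatchSlice]:
    # The split point includes the padding, so the first ubatch is created
    # at its final size; the second stays short until pad_out_ubatch_slice
    # stretches it after attention metadata creation.
    per_ubatch = (total_num_scheduled_tokens + num_pad_tokens) // 2
    first = slice(0, per_ubatch)
    second = slice(per_ubatch, total_num_scheduled_tokens)
    # Pure decode means 1 token per request, so the same slice serves as
    # both the request slice and the token slice.
    return [UbatchSlice(first, first), UbatchSlice(second, second)]

# 10 scheduled tokens + 2 pad tokens -> split at 6: [0, 6) and [6, 10)
print(split_decode_batch(10, 2))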
@@ -1528,34 +1527,23 @@ def get_padding(
        return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding

    def get_dp_padding_ubatch(
-            self, ubatch_slices: Optional[UBatchSlices]
+            self, total_num_scheduled_tokens: int, should_attempt_ubatching: bool
    ) -> tuple[bool, int, Optional[torch.Tensor]]:
        dp_size = self.vllm_config.parallel_config.data_parallel_size

        if dp_size == 1:
            # Early exit.
            return False, 0, None

-        if ubatch_slices is None:
+        if not should_attempt_ubatching:
            (should_ubatch,
             num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
                 False, 0, 0)
            assert should_ubatch is False
            assert num_tokens_across_dp is None
            return should_ubatch, 0, num_tokens_across_dp

-        first_ubatch_slice = ubatch_slices[0]
-        second_ubatch_slice = ubatch_slices[1]
-
-        first_ubatch_num_tokens = first_ubatch_slice.token_slice.stop - \
-            first_ubatch_slice.token_slice.start
-        second_ubatch_num_tokens = second_ubatch_slice.token_slice.stop - \
-            second_ubatch_slice.token_slice.start
-        # We don't support prefills yet so the two ubatches should only differ
-        # by at most one token
-        assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
-
-        num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
+        num_tokens_unpadded = total_num_scheduled_tokens
        num_tokens_padded = round_up(num_tokens_unpadded, 2)
        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
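A hedged sketch of the arithmetic this hunk relies on: the unpadded token count is rounded up to an even number so it divides into two equal ubatches. round_up below mirrors what vllm.utils.round_up does as I understand it, and deriving num_pad_tokens as the difference is an assumption; the cudagraph branch that can raise the padding further is cut off in the diff and not reproduced here.

def round_up(x: int, multiple: int) -> int:
    # round x up to the nearest multiple of `multiple`
    return ((x + multiple - 1) // multiple) * multiple

num_tokens_unpadded = 11                                 # e.g. 11 single-token requests
num_tokens_padded = round_up(num_tokens_unpadded, 2)     # -> 12, even
num_pad_tokens = num_tokens_padded - num_tokens_unpadded # -> 1 (assumed derivation)
assert (num_tokens_unpadded + num_pad_tokens) % 2 == 0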
@@ -1600,32 +1588,12 @@ def get_dp_padding_ubatch(
            num_tokens_unpadded
        return should_ubatch, num_pad_tokens, num_tokens_after_padding

-    # This doesn't actually pad the ubatch slices. It just shifts the
-    # split point to the correct value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_second_stage. Should be
-    # called after ubatch slicing but before attention metadata creation
-    def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
-                                   num_pad_tokens: int):
-        original_num_tokens = ubatch_slices[1].token_slice.stop
-        assert num_pad_tokens < original_num_tokens, \
-            f"num_pad_tokens {num_pad_tokens} " \
-            f"original_num_tokens {original_num_tokens}"
-        total_num_tokens_per_ubatch = (original_num_tokens +
-                                       num_pad_tokens) // 2
-        padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
-        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
-                                           original_num_tokens)
-
-        ubatch_slices[0] = UbatchSlice(padded_first_ubatch_slice,
-                                       padded_first_ubatch_slice)
-        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
-                                       padded_second_ubatch_slice)
-
    # This is where the second ubatch is adjusted to account for the padding.
    # Should be called after attention metadata creation. This just pads
    # the second ubatch slice out to the total number of tokens
    # (num_tokens + padding)
-    def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
+    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
                             num_total_tokens: int):
-        # TODO Add asserts to make sure stage one ran
        padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                           num_total_tokens)
        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
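To see both stages together, a self-contained toy run (plain slices, illustrative values): the split from _ubatch_split leaves the second token slice short, attention metadata is built on the unpadded slices, and only afterwards does pad_out_ubatch_slice stretch the second slice's stop out to num_total_tokens.

total, pad = 10, 2
per_ubatch = (total + pad) // 2                   # stage one: split at 6
slices = [slice(0, per_ubatch), slice(per_ubatch, total)]  # [0, 6), [6, 10)

# ... attention metadata would be built here, on the unpadded slices ...

num_total_tokens = total + pad                    # stage two, after metadata
slices[1] = slice(slices[1].start, num_total_tokens)       # [6, 12)
assert (slices[0].stop - slices[0].start
        == slices[1].stop - slices[1].start)      # both now span 6 entries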
@@ -1712,7 +1680,7 @@ def execute_model(
        num_input_tokens = num_scheduled_tokens
        if ubatch_slices and num_pad_tokens > 0:
            num_input_tokens += num_pad_tokens
-            self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
+            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
        elif ubatch_slices is None:
            num_pad, num_tokens_after_padding = self.get_padding(
                num_input_tokens)