
Commit 7427b2d

simplify ubatch padding
Signed-off-by: Sage Moore <[email protected]>
1 parent 44124af commit 7427b2d

1 file changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 23 additions & 55 deletions
@@ -602,32 +602,31 @@ def _ubatch_split(
             self.parallel_config.microbatching_token_threshold \
             and max_num_scheduled_tokens == 1
 
-        # For pure decode we can just create ubatches by cutting the request
-        # in half
-        ubatch_slices = None
-        if should_attempt_ubatching:
-            b0_reqs_end = num_reqs // 2
-            b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and \
-                b0_tokens_end < total_num_scheduled_tokens
-            ubatch_slices = [
-                UbatchSlice(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-                UbatchSlice(slice(b0_reqs_end, num_reqs),
-                            slice(b0_tokens_end, total_num_scheduled_tokens)),
-            ]
-
         # Don't microbatch unless every other DP worker is also microbatching
         num_pad_tokens = 0
         num_tokens_after_padding = None
         (should_ubatch, num_pad_tokens,
-         num_tokens_after_padding) = self.get_dp_padding_ubatch(ubatch_slices)
+         num_tokens_after_padding) = self.get_dp_padding_ubatch(total_num_scheduled_tokens,
+                                                                should_attempt_ubatching)
         if not should_ubatch:
             return (None, 0, None)
-        assert ubatch_slices
 
-        # Compute ubatch padding. This currently only accounts for DP padding
-        if num_pad_tokens > 0:
-            self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
+        # This doesn't actually pad the ubatch slices. It just initializes the
+        # split point to the correct value so that padding can be applied
+        # to the second ubatch in pad_out_ubatch_slice after attention
+        # metadata creation
+        assert num_pad_tokens < total_num_scheduled_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {total_num_scheduled_tokens}"
+        total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
+                                       num_pad_tokens) // 2
+        padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
+        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
+                                           total_num_scheduled_tokens)
+
+        # Note there's an assumption here that there's 1 token per request
+        ubatch_slices = [
+            UbatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
+            UbatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
+        ]
 
         return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
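
The new slicing above assumes one decode token per request and folds the DP padding into the split point. As a rough standalone sketch of that arithmetic (UbatchSlice is stubbed as a plain dataclass and split_ubatches is an illustrative name, not a vLLM helper):

from dataclasses import dataclass

@dataclass
class UbatchSlice:
    # Simplified stand-in: vLLM's UbatchSlice carries a request slice and a
    # token slice; for pure decode they span the same index range.
    request_slice: slice
    token_slice: slice

def split_ubatches(total_num_scheduled_tokens: int, num_pad_tokens: int):
    # The split point already accounts for DP padding, but the second slice
    # still stops at the unpadded token count; it gets stretched later
    # (after attention metadata creation) by pad_out_ubatch_slice.
    assert num_pad_tokens < total_num_scheduled_tokens
    tokens_per_ubatch = (total_num_scheduled_tokens + num_pad_tokens) // 2
    first = slice(0, tokens_per_ubatch)
    second = slice(tokens_per_ubatch, total_num_scheduled_tokens)
    return [UbatchSlice(first, first), UbatchSlice(second, second)]

# Example: 10 scheduled decode tokens plus 2 pad tokens -> the first ubatch
# covers tokens [0, 6), the second covers [6, 10) until the padding stage.
print(split_ubatches(10, 2))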

@@ -1528,34 +1527,23 @@ def get_padding(
         return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
 
     def get_dp_padding_ubatch(
-        self, ubatch_slices: Optional[UBatchSlices]
+        self, total_num_scheduled_tokens: int, should_attempt_ubatching: bool
     ) -> tuple[bool, int, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
 
         if dp_size == 1:
             # Early exit.
             return False, 0, None
 
-        if ubatch_slices is None:
+        if not should_attempt_ubatching:
             (should_ubatch,
              num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
                  False, 0, 0)
             assert should_ubatch is False
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
 
-        first_ubatch_slice = ubatch_slices[0]
-        second_ubatch_slice = ubatch_slices[1]
-
-        first_ubatch_num_tokens = first_ubatch_slice.token_slice.stop - \
-            first_ubatch_slice.token_slice.start
-        second_ubatch_num_tokens = second_ubatch_slice.token_slice.stop - \
-            second_ubatch_slice.token_slice.start
-        # We don't support prefills yet so the two ubatches should only differ
-        # by at most one token
-        assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
-
-        num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
+        num_tokens_unpadded = total_num_scheduled_tokens
         num_tokens_padded = round_up(num_tokens_unpadded, 2)
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                 and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
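
The hunk above only shows the start of the padded-size computation: the unpadded count is rounded up to an even number so it splits cleanly across the two ubatches, and with CUDA graphs enabled it can be bumped to a captured batch size. A hedged sketch of that idea follows; padded_ubatch_total is an illustrative name and the bisect-based size selection is an assumption, not necessarily the exact vLLM logic:

import bisect

def padded_ubatch_total(num_tokens_unpadded: int,
                        use_cudagraphs: bool,
                        cudagraph_batch_sizes: list[int]) -> int:
    # Each ubatch takes half the tokens, so the total must be even.
    num_tokens_padded = (num_tokens_unpadded + 1) // 2 * 2  # round_up(x, 2)
    if use_cudagraphs and num_tokens_unpadded <= cudagraph_batch_sizes[-1]:
        # Assumed behavior: pad up to the smallest captured graph size that
        # still fits the (even) token count.
        idx = bisect.bisect_left(cudagraph_batch_sizes, num_tokens_padded)
        num_tokens_padded = cudagraph_batch_sizes[idx]
    return num_tokens_padded

# e.g. 13 unpadded tokens with captured sizes [8, 16, 32] -> 16 in total,
# i.e. num_pad_tokens = 3 shared between the two ubatches.
print(padded_ubatch_total(13, True, [8, 16, 32]))
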
@@ -1600,32 +1588,12 @@ def get_dp_padding_ubatch(
                 num_tokens_unpadded
         return should_ubatch, num_pad_tokens, num_tokens_after_padding
 
-    # This doesn't actually pad the ubatch slices. It just shifts the
-    # split point to the correct value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_second_stage. Should be
-    # called after ubatch slicing but before attention meta data creation
-    def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
-                                   num_pad_tokens: int):
-        original_num_tokens = ubatch_slices[1].token_slice.stop
-        assert num_pad_tokens < original_num_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {original_num_tokens}"
-        total_num_tokens_per_ubatch = (original_num_tokens +
-                                       num_pad_tokens) // 2
-        padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
-        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
-                                           original_num_tokens)
-
-        ubatch_slices[0] = UbatchSlice(padded_first_ubatch_slice,
-                                       padded_first_ubatch_slice)
-        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
-                                       padded_second_ubatch_slice)
-
     # This is where the second ubatch is adjusted to account for the padding.
     # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
+    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
                                     num_total_tokens: int):
-        # TODO Add asserts to make sure stage one ran
         padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                            num_total_tokens)
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
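
For completeness, a toy version of the renamed second-stage helper: after attention metadata has been built, the second ubatch slice is simply stretched so it ends at the padded total. UbatchSlice is again a simplified stub and pad_out_second_ubatch is an illustrative name:

from dataclasses import dataclass

@dataclass
class UbatchSlice:
    request_slice: slice
    token_slice: slice

def pad_out_second_ubatch(ubatch_slices: list, num_total_tokens: int) -> None:
    # Stretch the second slice so it covers num_tokens + padding.
    start = ubatch_slices[1].token_slice.start
    padded = slice(start, num_total_tokens)
    ubatch_slices[1] = UbatchSlice(padded, padded)

# Continuing the earlier numbers: 10 tokens split at 6, then padded to 12,
# so both ubatches end up 6 tokens wide.
slices = [UbatchSlice(slice(0, 6), slice(0, 6)),
          UbatchSlice(slice(6, 10), slice(6, 10))]
pad_out_second_ubatch(slices, 12)
print(slices[1].token_slice)  # slice(6, 12, None)
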
@@ -1712,7 +1680,7 @@ def execute_model(
             num_input_tokens = num_scheduled_tokens
         if ubatch_slices and num_pad_tokens > 0:
             num_input_tokens += num_pad_tokens
-            self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
+            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
         elif ubatch_slices is None:
             num_pad, num_tokens_after_padding = self.get_padding(
                 num_input_tokens)
