
Commit b6d162f

padding bugfix
Signed-off-by: Sage Moore <[email protected]>
1 parent 0c54343 commit b6d162f

File tree

2 files changed: +30 -24 lines changed


vllm/forward_context.py

Lines changed: 16 additions & 11 deletions
@@ -93,27 +93,32 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         return num_tokens_tensor
 
     @staticmethod
-    def should_ubatch_across_dp(should_ubatch: bool, num_tokens_per_ubatch: int, dp_size: int,
+    def should_ubatch_across_dp(should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+                                padded_num_tokens_per_ubatch: int, dp_size: int,
                                 dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
 
-        tensor = torch.zeros(3, dp_size, device="cpu", dtype=torch.int32)
-        tensor[0][dp_rank] = num_tokens_per_ubatch
-        tensor[1][dp_rank] = 1 if should_ubatch else 0
+        tensor = torch.zeros(3, dp_size, device="cuda", dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
 
 
         from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(tensor, group=get_dp_group().cpu_group)
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
 
-        result: bool = bool(torch.all(tensor[1]== 1).item())
+        result: bool = bool(torch.all(tensor[2]== 1).item())
         if not result:
             return result, None
 
-        min_num_tokens_per_ubatch = tensor[0].min().item()
-        max_num_tokens_per_ubatch = tensor[0].max().item()
-        if max_num_tokens_per_ubatch >= 2 * min_num_tokens_per_ubatch:
-            logger.debug(f"Aborting ubatching {min_num_tokens_per_ubatch} {max_num_tokens_per_ubatch}")
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = orig_num_tokens_tensor.min().item()
+        padded_max_num_tokens = padded_num_tokens_tensor.max().item()
+        if padded_max_num_tokens >= 2 * orig_min_num_tokens:
+            logger.debug(f"Aborting ubatching {orig_min_num_tokens} {padded_max_num_tokens}")
             return False, None
-        return result, tensor[0, :]
+        return result, padded_num_tokens_tensor
 
     @staticmethod
     def make(
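To illustrate the decision the patched should_ubatch_across_dp now makes, here is a minimal single-process sketch. The helper name should_ubatch_across_dp_sketch and the per_rank list are hypothetical stand-ins: the real code fills only its own DP rank's column and materializes the full table via dist.all_reduce over the DP device group, whereas this sketch builds the whole table locally on CPU.

# Hypothetical single-process sketch of the cross-DP ubatch decision.
# Assumptions: the list-based "all-gather" replaces the real all_reduce,
# and the tensor lives on CPU instead of CUDA for simplicity.
from typing import Optional

import torch


def should_ubatch_across_dp_sketch(
    per_rank: list[tuple[bool, int, int]],  # (should_ubatch, orig_tokens, padded_tokens) per DP rank
) -> tuple[bool, Optional[torch.Tensor]]:
    # Row 0: original tokens per ubatch, row 1: padded tokens per ubatch,
    # row 2: per-rank ubatch votes. In the real code each rank writes only
    # its own column and the all_reduce (sum) fills in the rest.
    dp_size = len(per_rank)
    tensor = torch.zeros(3, dp_size, dtype=torch.int32)
    for rank, (should_ubatch, orig, padded) in enumerate(per_rank):
        tensor[0][rank] = orig
        tensor[1][rank] = padded
        tensor[2][rank] = 1 if should_ubatch else 0

    # Every rank must vote yes, otherwise no one ubatches.
    if not bool(torch.all(tensor[2] == 1).item()):
        return False, None

    # Abort if padding would more than double the smallest rank's real work:
    # compare the largest *padded* ubatch against the smallest *original* one.
    orig_min = tensor[0].min().item()
    padded_max = tensor[1].max().item()
    if padded_max >= 2 * orig_min:
        return False, None

    # All ranks size their ubatches from the padded token counts.
    return True, tensor[1, :]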

vllm/v1/worker/gpu_model_runner.py

Lines changed: 14 additions & 13 deletions
@@ -1539,7 +1539,7 @@ def get_dp_padding_ubatch(
         if ubatch_slices is None:
             (should_ubatch,
              num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
-                 False, 0)
+                 False, 0, 0)
             assert should_ubatch is False
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
@@ -1581,9 +1581,9 @@ def get_dp_padding_ubatch(
             should_ubatch = False
 
         # Note that we compute the number of padded tokens per ubatch
-        (should_ubatch,
-         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
-             should_ubatch, num_tokens_per_ubatch)
+        (should_ubatch,
+         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(should_ubatch,
+             num_tokens_unpadded // 2, num_tokens_per_ubatch)
         if not should_ubatch:
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp
@@ -1607,7 +1607,7 @@ def get_dp_padding_ubatch(
     def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
                                    num_pad_tokens: int):
         original_num_tokens = ubatch_slices[1].token_slice.stop
-        assert num_pad_tokens < original_num_tokens
+        assert num_pad_tokens < original_num_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {original_num_tokens}"
         total_num_tokens_per_ubatch = (original_num_tokens +
                                        num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
@@ -1631,16 +1631,16 @@ def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
                                        padded_second_ubatch_slice)
 
-    def should_ubatch_with_num_tokens(
-        self,
-        should_ubatch: bool,
-        num_tokens_per_ubatch: int,
-    ) -> tuple[bool, Optional[torch.Tensor]]:
+    def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+                                      padded_num_tokens_per_ubatch: int,
+                                      ) -> tuple[bool, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch,
-                                                  num_tokens_per_ubatch,
-                                                  dp_size, dp_rank)
+        return DPMetadata.should_ubatch_across_dp(should_ubatch,
+                                                  orig_num_tokens_per_ubatch,
+                                                  padded_num_tokens_per_ubatch,
+                                                  dp_size,
+                                                  dp_rank)
 
     def _pool(
         self,
@@ -2472,6 +2472,7 @@ def _dummy_run(
             should_ubatch, _ = self.should_ubatch_with_num_tokens(
                 should_ubatch,
                 num_tokens // 2,
+                num_tokens // 2,
             )
             assert cudagraph_runtime_mode in {
                 CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
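For reference, a hypothetical use of the sketch above, mirroring how get_dp_padding_ubatch now passes both the unpadded per-ubatch count (num_tokens_unpadded // 2) and the padded count; the token numbers are invented for illustration.

# Two DP ranks, both voting to ubatch: padded max (128) < 2x original min (96).
ok, padded = should_ubatch_across_dp_sketch([(True, 96, 128), (True, 112, 128)])
# -> True, tensor([128, 128], dtype=torch.int32)

# Aborted: rank 0's padded ubatch (256) is at least 2x rank 1's original ubatch (96).
ok, padded = should_ubatch_across_dp_sketch([(True, 200, 256), (True, 96, 128)])
# -> False, None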
