
Commit b2ed6c3

misc gpu model runner refactoring
Signed-off-by: Sage Moore <[email protected]>
1 parent 9e1f1af commit b2ed6c3

2 files changed: 40 additions, 37 deletions

vllm/compilation/ubatch_wrapper.py

Lines changed: 5 additions & 6 deletions
@@ -19,6 +19,7 @@

 logger = init_logger(__name__)

+
 @dataclasses.dataclass
 class UbatchMetadata:
     context: UBatchContext
@@ -47,7 +48,7 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig,
         self.device = device
         self.ready_barrier = threading.Barrier(3)

-        self.cudagraphs = {}
+        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

         self.cudagraph_wrapper = None
         self.graph_pool = None
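
The new annotation documents that captured CUDA graphs are keyed by the token count of the batch they were captured for. A minimal capture-or-replay sketch of that pattern, using a hypothetical GraphEntry stand-in rather than vLLM's actual CUDAGraphMetaData:

    import dataclasses
    from typing import Callable

    @dataclasses.dataclass
    class GraphEntry:        # hypothetical stand-in for CUDAGraphMetaData
        graph: object        # would hold a captured torch.cuda.CUDAGraph
        outputs: object      # static output buffers the replay writes into

    class GraphRegistry:
        def __init__(self) -> None:
            # Keyed by (padded) token count, mirroring dict[int, CUDAGraphMetaData].
            self.cudagraphs: dict[int, GraphEntry] = {}

        def run(self, num_tokens: int,
                capture: Callable[[], GraphEntry]) -> GraphEntry:
            # Capture once per batch size, then reuse the cached entry.
            if num_tokens not in self.cudagraphs:
                self.cudagraphs[num_tokens] = capture()
            return self.cudagraphs[num_tokens]
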
@@ -184,15 +185,12 @@ def _ubatch_thread(results, model, ubatch_metadata):

     def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                               positions, inputs_embeds, intermediate_tensors,
-                              compute_stream, dp_metadata,
-                              batch_descriptor,
+                              compute_stream, dp_metadata, batch_descriptor,
                               cudagraph_runtime_mode) -> list[UbatchMetadata]:

         # Create one forward context per ubatch
         forward_contexts = []
         for i, ubatch_slice in enumerate(ubatch_slices):
-            num_tokens = (ubatch_slice.token_slice.stop -
-                          ubatch_slice.token_slice.start)
             forward_contexts.append(
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
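
The deleted num_tokens local was unused: the per-ubatch token count is already encoded in each slice and can be recomputed on demand. A rough, self-contained sketch of how flat per-token inputs are viewed per ubatch (names and sizes here are illustrative, not the exact fields used above):

    # Toy flat per-token inputs for a 256-token batch.
    input_ids = list(range(256))
    positions = list(range(256))

    # Two illustrative ubatch token ranges covering the batch.
    ubatch_token_slices = [slice(0, 128), slice(128, 256)]

    for token_slice in ubatch_token_slices:
        # This is all the deleted local recomputed from the slice.
        num_tokens = token_slice.stop - token_slice.start
        ubatch_input_ids = input_ids[token_slice]
        ubatch_positions = positions[token_slice]
        assert len(ubatch_input_ids) == num_tokens
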
@@ -252,7 +250,8 @@ def __call__(self, *args, **kwargs):

         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
-            if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
+            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
+                                          CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
             else:
                 assert self.cudagraph_wrapper is not None
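
The reflow only wraps the tuple; the dispatch itself is unchanged: NONE and PIECEWISE modes call the wrapped runnable directly, while FULL capture goes through the CUDA-graph wrapper. A small sketch of that branch with stand-in types (Mode is a toy enum, not CUDAGraphMode):

    import enum

    class Mode(enum.Enum):  # toy stand-in for CUDAGraphMode
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    def dispatch(mode: Mode, runnable, cudagraph_wrapper):
        # Eager / piecewise-compiled paths call the model directly;
        # full-graph mode must go through the capture/replay wrapper.
        if mode in (Mode.NONE, Mode.PIECEWISE):
            return runnable()
        assert cudagraph_wrapper is not None
        return cudagraph_wrapper()

    assert dispatch(Mode.PIECEWISE, lambda: "eager", lambda: "graph") == "eager"
    assert dispatch(Mode.FULL, lambda: "eager", lambda: "graph") == "graph"
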

vllm/v1/worker/gpu_model_runner.py

Lines changed: 35 additions & 31 deletions
@@ -23,8 +23,8 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+from vllm.compilation.ubatch_utils import UbatchSlice, UBatchSlices
 from vllm.compilation.ubatch_wrapper import UBatchWrapper
-from vllm.compilation.ubatch_utils import (UbatchSlice, UBatchSlices)
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
@@ -60,7 +60,7 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    UbatchSlice, create_fast_prefill_custom_backend,
+    create_fast_prefill_custom_backend,
     reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -605,8 +605,8 @@ def _ubatch_split(
         num_pad_tokens = 0
         num_tokens_after_padding = None
         (should_ubatch, num_pad_tokens,
-         num_tokens_after_padding) = self.get_dp_padding_ubatch(total_num_scheduled_tokens,
-                                                                should_attempt_ubatching)
+         num_tokens_after_padding) = self.get_dp_padding_ubatch(
+             total_num_scheduled_tokens, should_attempt_ubatching)
         if not should_ubatch:
             return (None, 0, None)
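
For orientation, _ubatch_split decides whether to microbatch and, when it does, produces two token ranges that together cover the (possibly padded) batch. A toy sketch of that two-way split, using a hypothetical helper rather than the method above:

    def split_tokens(total_tokens: int) -> tuple[slice, slice]:
        # Illustrative only: split a flat token range into two microbatches.
        half = total_tokens // 2
        return slice(0, half), slice(half, total_tokens)

    first, second = split_tokens(10)
    assert (first, second) == (slice(0, 5), slice(5, 10))
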

@@ -1570,16 +1570,16 @@ def get_dp_padding_ubatch(
             should_ubatch = False

         # Note that we compute the number of padded tokens per ubatch
-        (should_ubatch,
-         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(should_ubatch,
-                                     num_tokens_unpadded // 2, num_tokens_per_ubatch)
+        (should_ubatch,
+         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
+             should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch)
         if not should_ubatch:
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp

         assert num_tokens_across_dp is not None

-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                 dp_size,
                                                 device="cpu",
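
Two things happen in this hunk: torch.max(...).item() returns a Python scalar whose static type is not a plain int, so the explicit int(...) cast gives type checkers a concrete int, and the per-rank tensor broadcasts the largest ubatch size so every rank pads to the same value. A small sketch with toy rank counts (values are illustrative):

    import torch

    # Toy per-rank ubatch token counts (two DP ranks).
    num_tokens_across_dp = torch.tensor([96, 128], dtype=torch.int32)
    dp_size = num_tokens_across_dp.numel()

    # .item() yields a Python scalar; int() pins it to a concrete int.
    max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())

    # Every rank pads its ubatch up to the largest ubatch across DP ranks.
    num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * dp_size,
                                            device="cpu", dtype=torch.int32)
    assert num_tokens_after_padding.tolist() == [128, 128]
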
@@ -1594,22 +1594,23 @@ def get_dp_padding_ubatch(
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
     def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
-                            num_total_tokens: int):
+                             num_total_tokens: int):
         padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                            num_total_tokens)
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
                                        padded_second_ubatch_slice)

-    def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
-                                      padded_num_tokens_per_ubatch: int,
-                                      ) -> tuple[bool, Optional[torch.Tensor]]:
+    def should_ubatch_with_num_tokens(
+        self,
+        should_ubatch: bool,
+        orig_num_tokens_per_ubatch: int,
+        padded_num_tokens_per_ubatch: int,
+    ) -> tuple[bool, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch,
-                                                  orig_num_tokens_per_ubatch,
-                                                  padded_num_tokens_per_ubatch,
-                                                  dp_size,
-                                                  dp_rank)
+        return DPMetadata.should_ubatch_across_dp(
+            should_ubatch, orig_num_tokens_per_ubatch,
+            padded_num_tokens_per_ubatch, dp_size, dp_rank)

     def _pool(
         self,
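
pad_out_ubatch_slice rewrites the second ubatch so that it ends at the padded total rather than at the original token count. A standalone sketch of that rewrite, with a hypothetical two-field stand-in for UbatchSlice (field names are assumed, not taken from the source):

    from typing import NamedTuple

    class ToyUbatchSlice(NamedTuple):  # hypothetical stand-in for UbatchSlice
        request_slice: slice
        token_slice: slice

    ubatch_slices = [
        ToyUbatchSlice(slice(0, 4), slice(0, 64)),
        ToyUbatchSlice(slice(4, 8), slice(64, 128)),
    ]

    # Pad the second ubatch out to the DP-padded total (160 tokens here).
    num_total_tokens = 160
    padded = slice(ubatch_slices[1].token_slice.start, num_total_tokens)
    ubatch_slices[1] = ToyUbatchSlice(padded, padded)
    assert ubatch_slices[1].token_slice == slice(64, 160)
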
@@ -2426,23 +2427,26 @@ def _dummy_run(
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         ubatch_enabled = self.parallel_config.enable_microbatching
+        num_tokens_across_dp = None
+        num_pad = 0
         should_ubatch = False
         if ubatch_enabled:
             should_ubatch = num_tokens >= \
                 self.parallel_config.microbatching_token_threshold and \
                 allow_microbatching
-            should_ubatch, _ = self.should_ubatch_with_num_tokens(
-                should_ubatch,
-                num_tokens // 2,
-                num_tokens // 2,
-            )
+
+            (should_ubatch, num_pad,
+             num_tokens_across_dp) = self.get_dp_padding_ubatch(
+                 num_tokens, should_ubatch)
+
+            # Currently the dummy run should only be ubatching during
+            # cuda graph capture, meaning all DP ranks should already
+            # have the same batch size
+            assert num_pad == 0
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

-        # Padding for DP
-        num_tokens_across_dp = None
-        num_pad = 0
         if not should_ubatch:
             num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
             num_tokens += num_pad
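
The dummy run now routes its ubatch decision through get_dp_padding_ubatch so that all DP ranks agree, and the new assert encodes the expectation that, during CUDA-graph capture, every rank is already running the same batch size, so no extra padding should appear. A toy check of that invariant (sizes are illustrative):

    # If every DP rank captures the same ubatch size, the max across ranks
    # equals the local size and no padding is added.
    per_rank_ubatch_tokens = [256, 256, 256, 256]
    local_ubatch_tokens = per_rank_ubatch_tokens[0]
    num_pad = max(per_rank_ubatch_tokens) - local_ubatch_tokens
    assert num_pad == 0
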
@@ -2497,12 +2501,12 @@ def _dummy_run(
             # We only support decode-only cudagraphs
             assert num_reqs == num_tokens
             assert num_tokens % 2 == 0
-            num_tokens_per_ubatch = num_tokens // 2
-            dp_size = self.vllm_config.parallel_config.data_parallel_size
-            num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
-                                                dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
+            # num_tokens_per_ubatch = num_tokens // 2
+            # dp_size = self.vllm_config.parallel_config.data_parallel_size
+            # num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
+            #                                     dp_size,
+            #                                     device="cpu",
+            #                                     dtype=torch.int32)
             ubatch_slices = [
                 UbatchSlice(slice(0, num_reqs // 2), slice(0,
                                                            num_tokens // 2)),
