Commit 2682bb7

Squashed commit of nm/lwilkinson/dbo-alt-schedules changes relative to origin/main
fixes
Signed-off-by: Lucas Wilkinson <[email protected]>

fix
Signed-off-by: Lucas Wilkinson <[email protected]>

fix
Signed-off-by: Lucas Wilkinson <[email protected]>

fixes and formatting
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent a2e6fa7 commit 2682bb7

File tree

14 files changed (+654, -485 lines)

examples/offline_inference/data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -259,4 +259,4 @@ def start(rank):
         elif proc.exitcode:
             exit_code = proc.exitcode
 
-    exit(exit_code)
+    exit(exit_code)

vllm/config/parallel.py

Lines changed: 10 additions & 0 deletions
@@ -151,6 +151,16 @@ class ParallelConfig:
     prefills. If the number of tokens in the request is greater than this
     threshold, microbatching will be used. Otherwise, the request will be
     processed in a single batch."""
+    microbatch_schedule: Literal["mlp_shared_overlap", "attn_shared_overlap"] = \
+        "mlp_shared_overlap"
+    """Schedule policy for microbatch overlap coordination.
+
+    Options:
+    - "mlp_shared_overlap": overlap MLP and communication across ubatches
+    - "attn_shared_overlap": overlap MLA attention and communication across
+      ubatches
+    See vllm/v1/worker/ubatching.py for diagrams of the schedules.
+    """
 
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
@@ -334,6 +334,7 @@ class EngineArgs:
         ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = \
         ParallelConfig.dbo_prefill_token_threshold
+    microbatch_schedule: str = ParallelConfig.microbatch_schedule
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
@@ -705,6 +706,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--dbo-prefill-token-threshold",
             **parallel_kwargs["dbo_prefill_token_threshold"])
+        parallel_group.add_argument(
+            "--microbatch-schedule",
+            dest="microbatch_schedule",
+            **parallel_kwargs["microbatch_schedule"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1329,6 +1334,7 @@ def create_engine_config(
             enable_dbo=self.enable_dbo,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
+            microbatch_schedule=self.microbatch_schedule,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,
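
A hedged usage sketch of the flag end to end: EngineArgs.microbatch_schedule is forwarded to the parallel config in create_engine_config(), mirroring --microbatch-schedule on the CLI. The model name and the exact shape of the returned config object are assumptions made for illustration.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",       # assumption: any MLA-capable model
    enable_dbo=True,
    microbatch_schedule="attn_shared_overlap",  # same value as the CLI flag
)
config = engine_args.create_engine_config()
# Assumed: the value lands on the parallel config, as in the hunk above.
print(config.parallel_config.microbatch_schedule)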

vllm/forward_context.py

Lines changed: 29 additions & 0 deletions
@@ -171,6 +171,34 @@ def should_ubatch_across_dp(
             return False, None
         return result, padded_num_tokens_tensor.cpu()
 
+    @staticmethod
+    def should_ubatch_across_dp(
+            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+            padded_num_tokens_per_ubatch: int, dp_size: int,
+            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+
+        tensor = torch.zeros(3, dp_size, device="cuda", dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+        result: bool = bool(torch.all(tensor[2] == 1).item())
+        if not result:
+            return result, None
+
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = orig_num_tokens_tensor.min().item()
+        padded_max_num_tokens = padded_num_tokens_tensor.max().item()
+        if padded_max_num_tokens >= 2 * orig_min_num_tokens:
+            logger.debug(f"Aborting ubatching {orig_min_num_tokens} {padded_max_num_tokens}")
+            return False, None
+        return result, padded_num_tokens_tensor
+
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
@@ -199,6 +227,7 @@ def make(
         if num_tokens_across_dp_cpu is None:
             num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
+
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
         return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
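
To make the consensus rule concrete, here is a hedged single-process sketch of the same decision with the all_reduce replaced by a pre-filled per-rank tensor. The helper name and its inputs are illustrative; only the two checks (all ranks must opt in, and padding must not at least double the smallest rank's work) mirror the diff.

import torch

def should_ubatch_decision(orig: list[int], padded: list[int],
                           opt_in: list[bool]) -> tuple[bool, torch.Tensor | None]:
    # Rows mimic the all-reduced tensor above: row 0 = original tokens per
    # ubatch, row 1 = padded tokens per ubatch, row 2 = 1 if that DP rank
    # wants to microbatch.
    tensor = torch.tensor([orig, padded, [int(x) for x in opt_in]],
                          dtype=torch.int32)

    # Every DP rank must opt in, otherwise fall back to a single batch.
    if not bool(torch.all(tensor[2] == 1).item()):
        return False, None

    # Abort when padding every rank up to the largest padded ubatch would at
    # least double the smallest rank's real work: the imbalance outweighs the
    # overlap benefit.
    if tensor[1].max().item() >= 2 * tensor[0].min().item():
        return False, None
    return True, tensor[1]

# Rank 1 is small: padding both ranks to 96 tokens >= 2 * 40 -> abort.
print(should_ubatch_decision([64, 40], [96, 96], [True, True]))
# Balanced ranks keep microbatching.
print(should_ubatch_decision([64, 60], [64, 64], [True, True]))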
