Skip to content

Commit d3bd171

Browse files
authored
[Benchmark] Support benchmark throughput for external launcher DP (vllm-project#25913)
Signed-off-by: Zhuohan Li <[email protected]>
1 parent 89e4050 commit d3bd171

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

vllm/benchmarks/throughput.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,23 @@ def get_requests(args, tokenizer):
358358
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
359359
# Remove None values
360360
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
361-
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
361+
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
362+
requests = filter_requests_for_dp(requests, args.data_parallel_size)
363+
return requests
364+
365+
366+
def filter_requests_for_dp(requests, data_parallel_size):
    """Return the subset of ``requests`` owned by this data-parallel rank.

    Requests are distributed round-robin: the request at index ``i`` goes to
    the data-parallel rank ``i % data_parallel_size``.

    Args:
        requests: Iterable of sampled benchmark requests.
        data_parallel_size: Number of data-parallel replicas. When it is 1,
            ``requests`` is returned unchanged and the environment is not
            consulted.

    Returns:
        ``requests`` itself when ``data_parallel_size == 1``; otherwise a new
        list with this rank's share of the requests.

    Raises:
        KeyError: If ``RANK`` or ``WORLD_SIZE`` is missing from the
            environment while ``data_parallel_size != 1``.
        ValueError: If ``data_parallel_size`` is not positive, or
            ``WORLD_SIZE`` is not a positive multiple of it.
    """
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    if data_parallel_size < 1:
        raise ValueError(
            f"data_parallel_size must be positive, got {data_parallel_size}")

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Each data-parallel replica must own the same, non-zero number of
    # processes. Otherwise the integer division below either divides by zero
    # or yields a rank >= data_parallel_size, which would silently receive
    # no requests at all.
    if world_size < data_parallel_size or world_size % data_parallel_size:
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must be a positive multiple of "
            f"data_parallel_size ({data_parallel_size})")

    ranks_per_replica = world_size // data_parallel_size
    data_parallel_rank = global_rank // ranks_per_replica
    return [
        r for i, r in enumerate(requests)
        if i % data_parallel_size == data_parallel_rank
    ]
362378

363379

364380
def validate_args(args):
@@ -453,12 +469,17 @@ def validate_args(args):
453469
if args.backend == "mii" and args.tokenizer != args.model:
454470
raise ValueError(
455471
"Tokenizer must be the same as the model for MII backend.")
456-
457-
# --data-parallel is not supported currently.
458-
# https://github.com/vllm-project/vllm/issues/16222
459-
if args.data_parallel_size > 1:
472+
473+
if args.data_parallel_size > 1 and (
474+
args.distributed_executor_backend != "external_launcher"
475+
or args.async_engine):
476+
# --data-parallel is not fully supported.
477+
# Old issue: https://github.com/vllm-project/vllm/issues/16222
478+
# Currently we only support data parallel with external launcher
479+
# mode (i.e., launch with torchrun).
460480
raise ValueError(
461-
"Data parallel is not supported in offline benchmark, "
481+
"Data parallel is only supported with external launcher mode "
482+
"with synchronous engine in offline benchmark, "
462483
"please use benchmark serving instead"
463484
)
464485

0 commit comments

Comments
 (0)