From 3d2bc2f52289d8f4d66fb1ad0221399fbd1ec7c5 Mon Sep 17 00:00:00 2001
From: Lu Fang
Date: Fri, 19 Sep 2025 11:13:04 -0700
Subject: [PATCH] support torchrun DP mode

---
 lm_eval/models/vllm_causallms.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index be442809e31..3718464edf3 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -172,6 +172,7 @@ def __init__(
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
+            "data_parallel_size": int(data_parallel_size),
             "enable_lora": True if lora_local_path else False,
             "max_lora_rank": int(max_lora_rank),
         }
@@ -181,8 +182,18 @@ def __init__(
             if isinstance(batch_size, str) and "auto" in batch_size
             else int(batch_size)
         )
+
+        self.dp_group = None
+        self.torch_dist = None
+        self.is_external_launcher_dp = self.model_args.get("distributed_executor_backend", None) == "external_launcher" and self.data_parallel_size > 1
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
+        elif self.is_external_launcher_dp:
+            self.model = LLM(**self.model_args)
+            from vllm.distributed.parallel_state import get_dp_group
+            import torch.distributed as dist
+            self.dp_group = get_dp_group()
+            self.torch_dist = dist
         else:
             eval_logger.warning(
                 "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
             )
@@ -417,6 +428,22 @@ def run_inference_one_model(
                 list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
             )
             procs, resq = [], Queue()
+            if self.is_external_launcher_dp:
+                dp_rank = self.model.llm_engine.vllm_config.parallel_config.data_parallel_rank
+                local_requests = list(requests)[dp_rank]
+                local_sampling_params = list(sampling_params)[dp_rank]
+                local_results = self.model.generate(
+                    [TokensPrompt(prompt_token_ids=request) for request in local_requests],
+                    sampling_params=local_sampling_params,
+                    use_tqdm=True if self.batch_size == "auto" else False,
+                )
+                # All gather results across data parallel group
+                assert self.dp_group is not None
+                assert self.torch_dist is not None
+                # Gather results from all DP ranks
+                gathered_results = [None] * self.dp_group.world_size
+                self.torch_dist.all_gather_object(gathered_results, local_results, group=self.dp_group.cpu_group)
+                return undistribute(gathered_results)
             # We use Process as it is non-daemonic
             try:
                 for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
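
Note: the gather step in the last hunk relies on torch.distributed.all_gather_object, which collects one Python object per rank into a rank-ordered list. Below is a minimal single-process sketch of that pattern, not part of the patch; the gloo backend, port, world size of 1, and placeholder results are illustrative assumptions so the snippet runs standalone.

import torch.distributed as dist

# Illustrative only: a one-process group so the call runs locally.
# Under torchrun, each DP rank would join the same group and contribute
# its own slice of generation results.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

local_results = ["output for prompt 0", "output for prompt 1"]  # this rank's slice

# One slot per rank; after the call, gathered[r] holds rank r's object.
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_results)

# A plain flatten suffices here; the patch instead calls lm-eval's
# undistribute() because requests were split round-robin (interleaved)
# across ranks by more_itertools.distribute.
flat = [item for chunk in gathered for item in chunk]
print(flat)

dist.destroy_process_group()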