Commit 5a481f4
[vllm] data parallel for V1 (#3011)
* add data_parallel for V1
* use Process instead of Queue
* ray used if V0 DP
* better error handling
* fix truncation warning comparison
1 parent 7aaceee commit 5a481f4
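
For context, a rough usage sketch (not part of the commit): with this change, the V1 engine path handles data parallelism through plain multiprocessing, while V0 still routes through Ray. The model name, task, and parallel sizes below are placeholders; `simple_evaluate`, the `vllm` model type, and the `data_parallel_size` / `tensor_parallel_size` model args are the existing lm-eval interfaces this file plugs into.

# Hypothetical invocation sketch; model and task names are placeholders.
import os

import lm_eval

# "1" (the default assumed here) exercises the new multiprocessing DP path;
# "0" falls back to the Ray-based V0 data-parallel path.
os.environ["VLLM_USE_V1"] = "1"

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=EleutherAI/pythia-1.4b,data_parallel_size=2,tensor_parallel_size=1",
    tasks=["lambada_openai"],
    batch_size="auto",  # manual batching is not compatible with data parallelism
)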

1 file changed (+148, -12)

lm_eval/models/vllm_causallms.py

Lines changed: 148 additions & 12 deletions
@@ -1,8 +1,13 @@
 import copy
+import gc
 import inspect
 import logging
+import os
 from importlib.metadata import version
 from importlib.util import find_spec
+from multiprocessing import Process, Queue
+from queue import Empty
+from time import sleep
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 from more_itertools import distribute
@@ -29,6 +34,7 @@
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils import get_open_port

 if parse_version(version("vllm")) >= parse_version("0.8.3"):
     from vllm.entrypoints.chat_utils import resolve_hf_chat_template
@@ -41,6 +47,63 @@
 eval_logger = logging.getLogger(__name__)


+def _vllm_mp_worker(
+    model_args: dict,
+    sampling_params: "SamplingParams",
+    requests: list[list[int]],
+    lora_request: "LoRARequest",
+    result_queue: "Queue",
+    dp_size: int,
+    local_dp_rank: int,
+    dp_master_port: int,
+    dp_master_ip: str = "127.0.0.1",
+) -> None:
+    """
+    Worker process for vLLM multiprocessing.
+    Initializes a vLLM engine, processes requests, and puts results or errors
+    onto the result_queue.
+    """
+
+    if not requests:
+        result_queue.put((local_dp_rank, []))
+        return None
+
+    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
+    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
+
+    llm = None
+    try:
+        llm = LLM(**model_args)
+        res = llm.generate(
+            prompt_token_ids=requests,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )
+        # Give engines time to pause their processing loops before exiting.
+        sleep(1)
+        result_queue.put((local_dp_rank, res))
+
+    except Exception as e:
+        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
+        eval_logger.error(error_message, exc_info=True)
+        result_queue.put((local_dp_rank, {"error": error_message}))
+
+    finally:
+        if llm is not None:
+            try:
+                del llm
+                gc.collect()
+            except Exception as e_cleanup:
+                eval_logger.warning(
+                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
+                    exc_info=True,
+                )
+
+    return None
+
+
 @register_model("vllm")
 class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -83,7 +146,7 @@ def __init__(
         assert max_length is None or max_model_len is None, (
             "Either max_length or max_model_len may be provided, but not both"
         )
-
+        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
         self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
         self.data_parallel_size = int(data_parallel_size)
@@ -98,6 +161,7 @@ def __init__(
             "trust_remote_code": trust_remote_code,
             "tensor_parallel_size": int(tensor_parallel_size),
             "max_model_len": int(self._max_length) if self._max_length else None,
+            "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
@@ -115,7 +179,11 @@ def __init__(
             eval_logger.warning(
                 "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
             )
-            self.model_args["distributed_executor_backend"] = "ray"
+            self.model_args["distributed_executor_backend"] = (
+                "ray"
+                if not self.V1
+                else self.model_args.get("distributed_executor_backend", None)
+            )
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")

@@ -279,7 +347,7 @@ def _model_generate(
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
-        if self.data_parallel_size > 1:
+        if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
             # see https://github.com/vllm-project/vllm/issues/973
@@ -310,14 +378,83 @@ def run_inference_one_model(
             ray.shutdown()
             # flatten results
             return undistribute(results)
+        elif self.data_parallel_size > 1:
+            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
+            dp_size = self.data_parallel_size
+            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
+            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
+
+            requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+
+            procs, resq = [], Queue()
+            # We use Process as it is non-daemonic
+            try:
+                for rank, req in enumerate(requests):
+                    proc = Process(
+                        target=_vllm_mp_worker,
+                        args=(
+                            self.model_args.copy(),
+                            sampling_params,
+                            req,
+                            self.lora_request,
+                            resq,
+                            dp_size,
+                            rank,
+                            dp_master_port,
+                            dp_master_ip,
+                        ),
+                    )
+                    proc.start()
+                    procs.append(proc)
+
+                # Collect results
+                rank_res = {}
+                while len(rank_res) < len(procs):
+                    try:
+                        rank, result = resq.get(timeout=30)
+                        if isinstance(result, dict) and "error" in result:
+                            raise RuntimeError(result["error"])
+                        rank_res[rank] = result
+                    except Empty:
+                        dead_procs = [
+                            idx
+                            for idx, p in enumerate(procs)
+                            if not p.is_alive() and idx not in rank_res
+                        ]
+                        if dead_procs:
+                            raise RuntimeError(
+                                f"Worker processes {dead_procs} died unexpectedly"
+                            )
+                        continue
+
+                results = [rank_res[i] for i in range(len(procs))]
+                return undistribute(results)
+
+            # cleanup
+            finally:
+                try:
+                    resq.close()
+                    resq.join_thread()
+                except Exception:
+                    eval_logger.debug(
+                        "Failed to close vllm DP results queue", exc_info=True
+                    )
+                for proc in procs:
+                    proc.join(timeout=10)
+                    if proc.is_alive():
+                        proc.terminate()
+                        proc.join(timeout=5)
+                        if proc.is_alive():
+                            proc.kill()

-        outputs = self.model.generate(
-            prompt_token_ids=requests,
-            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
-            lora_request=self.lora_request,
-        )
-        return outputs
+        else:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+            return outputs

     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
@@ -507,8 +644,7 @@ def _collate(x):
             for cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
-                    >= self.max_length
-                ):
+                ) > self.max_length:
                     eval_logger.warning(
                         f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                     )
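
As a standalone illustration of the pattern the new V1 branch relies on (not part of the commit): one non-daemonic Process per data-parallel rank, results funneled back through a shared Queue, a read timeout so dead workers are noticed, and joins that escalate to terminate/kill. A minimal sketch; all names and the doubling "work" are illustrative placeholders.

# Minimal sketch of the Process/Queue fan-out used above; names are hypothetical.
from multiprocessing import Process, Queue
from queue import Empty


def _worker(rank: int, payload: list, result_queue: Queue) -> None:
    # Stand-in for _vllm_mp_worker: do the per-rank work and report back.
    result_queue.put((rank, [x * 2 for x in payload]))


if __name__ == "__main__":
    shards = [[1, 2], [3, 4]]  # one shard of requests per DP rank
    resq = Queue()
    procs = [
        Process(target=_worker, args=(i, shard, resq))
        for i, shard in enumerate(shards)
    ]
    for p in procs:
        p.start()

    # Collect results; the timeout lets us detect workers that died silently.
    rank_res = {}
    while len(rank_res) < len(procs):
        try:
            rank, result = resq.get(timeout=30)
            rank_res[rank] = result
        except Empty:
            if any(
                not p.is_alive() and i not in rank_res for i, p in enumerate(procs)
            ):
                raise RuntimeError("a worker died before reporting its results")

    # Join with escalating terminate/kill, mirroring the cleanup in the diff.
    for p in procs:
        p.join(timeout=10)
        if p.is_alive():
            p.terminate()
            p.join(timeout=5)
            if p.is_alive():
                p.kill()

    print([rank_res[i] for i in range(len(procs))])  # results in rank order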
