@@ -1,8 +1,13 @@
 import copy
+import gc
 import inspect
 import logging
+import os
 from importlib.metadata import version
 from importlib.util import find_spec
+from multiprocessing import Process, Queue
+from queue import Empty
+from time import sleep
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 from more_itertools import distribute
@@ -29,6 +34,7 @@
     from vllm import LLM, SamplingParams
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
+    from vllm.utils import get_open_port
 
     if parse_version(version("vllm")) >= parse_version("0.8.3"):
         from vllm.entrypoints.chat_utils import resolve_hf_chat_template
@@ -41,6 +47,63 @@
 eval_logger = logging.getLogger(__name__)
 
 
+def _vllm_mp_worker(
+    model_args: dict,
+    sampling_params: "SamplingParams",
+    requests: list[list[int]],
+    lora_request: "LoRARequest",
+    result_queue: "Queue",
+    dp_size: int,
+    local_dp_rank: int,
+    dp_master_port: int,
+    dp_master_ip: str = "127.0.0.1",
+) -> None:
+    """
+    Worker process for vLLM multiprocessing.
+    Initializes a vLLM engine, processes requests, and puts results or errors
+    onto the result_queue.
+    """
+
+    if not requests:
+        result_queue.put((local_dp_rank, []))
+        return None
+
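+    # vLLM reads these VLLM_DP_* variables when the engine is constructed:
+    # each data-parallel worker sets its own rank and shares the coordinator's
+    # IP/port (same scheme as vLLM's offline data_parallel example).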
+    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
+    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
+
+    llm = None
+    try:
+        llm = LLM(**model_args)
+        res = llm.generate(
+            prompt_token_ids=requests,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )
+        # Give engines time to pause their processing loops before exiting.
+        sleep(1)
+        result_queue.put((local_dp_rank, res))
+
+    except Exception as e:
+        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
+        eval_logger.error(error_message, exc_info=True)
+        result_queue.put((local_dp_rank, {"error": error_message}))
+
+    finally:
+        if llm is not None:
+            try:
+                del llm
+                gc.collect()
+            except Exception as e_cleanup:
+                eval_logger.warning(
+                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
+                    exc_info=True,
+                )
+
+    return None
+
+
 @register_model("vllm")
 class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -83,7 +146,7 @@ def __init__(
         assert max_length is None or max_model_len is None, (
             "Either max_length or max_model_len may be provided, but not both"
         )
-
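+        # V1 is assumed unless the user exports VLLM_USE_V1=0; this flag decides
+        # whether data parallelism goes through Ray (legacy engine) or the
+        # worker processes spawned in _model_generate.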
+        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
         self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
         self.data_parallel_size = int(data_parallel_size)
@@ -98,6 +161,7 @@ def __init__(
98161 "trust_remote_code" : trust_remote_code ,
99162 "tensor_parallel_size" : int (tensor_parallel_size ),
100163 "max_model_len" : int (self ._max_length ) if self ._max_length else None ,
164+ "max_num_seqs" : kwargs .get ("max_num_seqs" , max_batch_size ),
101165 "swap_space" : int (swap_space ),
102166 "quantization" : quantization ,
103167 "seed" : int (seed ),
@@ -115,7 +179,11 @@ def __init__(
             eval_logger.warning(
                 "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
             )
-            self.model_args["distributed_executor_backend"] = "ray"
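+            # With the V1 engine, data-parallel generation is handled by the
+            # worker processes in _model_generate, so the Ray executor is only
+            # forced for the legacy engine; otherwise any user-provided backend
+            # is kept.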
+            self.model_args["distributed_executor_backend"] = (
+                "ray"
+                if not self.V1
+                else self.model_args.get("distributed_executor_backend", None)
+            )
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")
 
@@ -279,7 +347,7 @@ def _model_generate(
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
-        if self.data_parallel_size > 1:
+        if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
             # see https://github.com/vllm-project/vllm/issues/973
@@ -310,14 +378,83 @@ def run_inference_one_model(
             ray.shutdown()
             # flatten results
             return undistribute(results)
+        elif self.data_parallel_size > 1:
+            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
+            dp_size = self.data_parallel_size
+            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
+            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
+
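+            # distribute() deals requests across the ranks round-robin; the
+            # generator is consumed lazily, one rank's share at a time, inside
+            # the spawn loop below.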
+            requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+
+            procs, resq = [], Queue()
+            # We use Process as it is non-daemonic
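+            # (daemonic processes may not create children, and each vLLM engine
+            # may need to spawn its own worker processes)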
+            try:
+                for rank, req in enumerate(requests):
+                    proc = Process(
+                        target=_vllm_mp_worker,
+                        args=(
+                            self.model_args.copy(),
+                            sampling_params,
+                            req,
+                            self.lora_request,
+                            resq,
+                            dp_size,
+                            rank,
+                            dp_master_port,
+                            dp_master_ip,
+                        ),
+                    )
+                    proc.start()
+                    procs.append(proc)
+
+                # Collect results
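+                # Poll with a timeout so a worker that dies without reporting a
+                # result (or an error) is noticed instead of blocking forever.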
+                rank_res = {}
+                while len(rank_res) < len(procs):
+                    try:
+                        rank, result = resq.get(timeout=30)
+                        if isinstance(result, dict) and "error" in result:
+                            raise RuntimeError(result["error"])
+                        rank_res[rank] = result
+                    except Empty:
+                        dead_procs = [
+                            idx
+                            for idx, p in enumerate(procs)
+                            if not p.is_alive() and idx not in rank_res
+                        ]
+                        if dead_procs:
+                            raise RuntimeError(
+                                f"Worker processes {dead_procs} died unexpectedly"
+                            )
+                        continue
+
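+                # Reassemble per-rank outputs in rank order; undistribute()
+                # restores the original request order that distribute() split up.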
+                results = [rank_res[i] for i in range(len(procs))]
+                return undistribute(results)
+
+            # cleanup
+            finally:
+                try:
+                    resq.close()
+                    resq.join_thread()
+                except Exception:
+                    eval_logger.debug(
+                        "Failed to close vllm DP results queue", exc_info=True
+                    )
+                for proc in procs:
+                    proc.join(timeout=10)
+                    if proc.is_alive():
+                        proc.terminate()
+                        proc.join(timeout=5)
+                        if proc.is_alive():
+                            proc.kill()
 
-        outputs = self.model.generate(
-            prompt_token_ids=requests,
-            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
-            lora_request=self.lora_request,
-        )
-        return outputs
+        else:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+            return outputs
 
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
@@ -507,8 +644,7 @@ def _collate(x):
             for cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
-                    >= self.max_length
-                ):
+                ) > self.max_length:
                     eval_logger.warning(
                         f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                     )