
Commit 6edd75c

fix vllm compatibility issue (#417)
* Use vllm/sglang server wrapper for local launching and fix an NCCL issue with vllm
* fix
1 parent: 73127a2 · commit: 6edd75c

7 files changed: +138 −90 lines changed


areal/engine/base_hf_engine.py

Lines changed: 0 additions & 4 deletions
@@ -117,10 +117,6 @@ def parallelism_group(self) -> dist.ProcessGroup:

     def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
         backend = current_platform.communication_backend
-        if current_platform.communication_backend == "nccl":
-            # Required by NCCL weight update group for SGLang
-            os.environ["NCCL_CUMEM_ENABLE"] = "0"
-            os.environ["NCCL_NVLS_ENABLE"] = "0"
         if not dist.is_initialized():
             # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
             # NOTE: device_id **SHOULD NOT** be passed into init_process_group,

areal/engine/sglang_remote.py

Lines changed: 22 additions & 1 deletion
@@ -26,13 +26,19 @@
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
 from areal.utils.http import arequest_with_retry, get_default_connector
+from areal.utils.launcher import wait_llm_server_addrs

 RID_CACHE_SIZE = 128


 class RemoteSGLangEngine(InferenceEngine):

     def __init__(self, config: InferenceEngineConfig):
+        if current_platform.communication_backend == "nccl":
+            # Required by NCCL weight update group.
+            os.environ["NCCL_CUMEM_ENABLE"] = "0"
+            os.environ["NCCL_NVLS_ENABLE"] = "0"
+
         self.config = config

         self.rid_to_address = {}

@@ -83,9 +89,24 @@ def initialize(

         if addr:
             self.addresses = addr if isinstance(addr, list) else [addr]
+            self.logger.info(f"Get server addresses from the `addr` argument.")
         else:
+            if (
+                self.config.experiment_name is not None
+                and self.config.trial_name is not None
+            ):
+                try:
+                    self.addresses = wait_llm_server_addrs(
+                        experiment_name=self.config.experiment_name,
+                        trial_name=self.config.trial_name,
+                        timeout=1,
+                    )
+                    self.logger.info(f"Get server addresses from name_resolve.")
+                except TimeoutError:
+                    pass
+            if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
                 # When addr is not provided, fallback to reading addrs from env var
-            self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+                self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
         if not self.addresses:
             raise RuntimeError(
                 "No configured SGLang servers. Please pass in SGLang server addresses by arguments "

areal/engine/vllm_remote.py

Lines changed: 17 additions & 1 deletion
@@ -26,6 +26,7 @@
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
 from areal.utils.http import arequest_with_retry, get_default_connector
+from areal.utils.launcher import wait_llm_server_addrs

 RID_CACHE_SIZE = 128


@@ -90,9 +91,24 @@ def initialize(

         if addr:
             self.addresses = addr if isinstance(addr, list) else [addr]
+            self.logger.info(f"Get server addresses from the `addr` argument.")
         else:
+            if (
+                self.config.experiment_name is not None
+                and self.config.trial_name is not None
+            ):
+                try:
+                    self.addresses = wait_llm_server_addrs(
+                        experiment_name=self.config.experiment_name,
+                        trial_name=self.config.trial_name,
+                        timeout=1,
+                    )
+                    self.logger.info(f"Get server addresses from name_resolve.")
+                except TimeoutError:
+                    pass
+            if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
                 # When addr is not provided, fallback to reading addrs from env var
-            self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+                self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
         if not self.addresses:
             raise RuntimeError(
                 "No configured vLLM servers. Please pass in vLLM server addresses by arguments "

areal/launcher/local.py

Lines changed: 54 additions & 54 deletions
@@ -22,8 +22,14 @@
 )
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
-from areal.utils.launcher import JobException, JobInfo, JobState, get_env_vars
-from areal.utils.network import find_free_ports, gethostip
+from areal.utils.launcher import (
+    JobException,
+    JobInfo,
+    JobState,
+    get_env_vars,
+    wait_llm_server_addrs,
+)
+from areal.utils.network import find_free_ports
 from areal.utils.recover import check_if_recover

 logger = logging.getLogger("Local Scheduler")

@@ -136,7 +142,9 @@ def submit_array(
             )
             c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}"
             logger.info("Starting local process with command: %s", c)
-            process = subprocess.Popen(c, shell=isinstance(c, str))
+            process = subprocess.Popen(
+                c, shell=isinstance(c, str), stdout=sys.stdout, stderr=sys.stdout
+            )
             self._jobs[f"{job_name}/{offset + i}"] = process
             self._job_counter[job_name] += 1


@@ -275,72 +283,64 @@ def local_main(config, run_id: int = 0):
         f"run_id={run_id}, is_recover_run={is_recover_run}"
     )

-    server_cmd = []
-    server_addrs = []
-    if alloc_mode.gen_backend == "sglang":
-        base_seed = config.sglang.random_seed
-        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
-        ports = find_free_ports(alloc_mode.gen.dp_size * 2, port_range=(10000, 50000))
-        host_ip = gethostip()
-        host = "localhost" if not config.sglang.enable_metrics else host_ip
-        for i in range(alloc_mode.gen.dp_size):
-            config.sglang.random_seed = base_seed + i
-            cmd = SGLangConfig.build_cmd(
-                config.sglang,
-                host=host,
-                tp_size=alloc_mode.gen.tp_size,
-                base_gpu_id=0,
-                port=ports[i * 2],
-                dist_init_addr=f"localhost:{ports[i*2+1]}",
-            )
-            server_cmd.append(cmd)
-            server_addrs.append(f"{host}:{ports[i * 2]}")
+    if alloc_mode.gen_backend in ("sglang", "vllm"):
+        # Launcher should launch llm servers according to allocation mode.
+        if alloc_mode.gen_backend == "sglang":
+            config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+            random_seed = config.sglang.random_seed
+        else:
+            config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
+            random_seed = config.vllm.seed
+
+        backend_spec = {
+            "sglang": {
+                "module": "areal.launcher.sglang_server",
+                "seed_arg": "sglang.random_seed",
+                "set_device_env": False,
+            },
+            "vllm": {
+                "module": "areal.launcher.vllm_server",
+                "seed_arg": "vllm.seed",
+                "set_device_env": True,  # vLLM needs `device_control_env_var` to control GPU allocation
+            },
+        }
+
+        spec = backend_spec[alloc_mode.gen_backend]
+
+        base_seed = random_seed
+        seed_arg = spec["seed_arg"]
+        module = spec["module"]
+        server_cmd = (
+            f"python3 -m {module} {' '.join(sys.argv[2:])} {seed_arg}={base_seed}"
+        )

         # Launch inference servers.
         launcher.submit_array(
             job_name="llm_server",
             cmd=server_cmd,
-            count=alloc_mode.gen.dp_size,
-            gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size,
+            count=1,
+            gpu=alloc_mode.gen.pp_size
+            * alloc_mode.gen.tp_size
+            * alloc_mode.gen.dp_size,
             env_vars=get_env_vars(
                 config.cluster.cluster_name,
                 config.launcher.inference_server_env_vars,
             ),
         )
-        logger.info(
-            f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
-        )
-    elif alloc_mode.gen_backend == "vllm":
-        base_seed = config.vllm.seed
-        config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
-        ports = find_free_ports(alloc_mode.gen.dp_size * 2, port_range=(10000, 50000))
-        host = "localhost"
-        for i in range(alloc_mode.gen.dp_size):
-            config.vllm.seed = base_seed + i
-            cmd = vLLMConfig.build_cmd(
-                config.vllm,
-                host=host,
-                tp_size=alloc_mode.gen.tp_size,
-                port=ports[i * 2],
-                dist_init_addr=f"localhost:{ports[i*2+1]}",
-            )
-            server_cmd.append(cmd)
-            server_addrs.append(f"{host}:{ports[i * 2]}")

-        # Launch inference servers.
-        launcher.submit_array(
-            job_name="llm_server",
-            cmd=server_cmd,
-            count=alloc_mode.gen.dp_size,
-            gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size,
-            env_vars=get_env_vars(
-                config.cluster.cluster_name,
-                config.launcher.inference_server_env_vars,
-            ),
+        # Get llm server addresses by name resolve
+        try:
+            server_addrs = wait_llm_server_addrs(
+                config.experiment_name,
+                config.trial_name,
+                n_rollout_servers=alloc_mode.gen.dp_size,
             )
             logger.info(
                 f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
             )
+        except (TimeoutError, KeyboardInterrupt) as e:
+            launcher.stop_all(signal="SIGINT")
+            raise e

     # Launch trainer entrypoint
     if alloc_mode.type_ != AllocationType.LLM_SERVER_ONLY:
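
With this change the local launcher no longer assembles per-replica SGLang/vLLM command lines and port lists itself; it submits a single wrapper-module process that owns the whole generation allocation, then blocks on `wait_llm_server_addrs` to learn the addresses. A hedged sketch of how the submitted command is put together (the backend, seed, and CLI arguments below are placeholders for `alloc_mode.gen_backend`, the config seed, and `' '.join(sys.argv[2:])`):

# Sketch of the server command the launcher now submits (names follow the diff above).
backend_spec = {
    "sglang": {"module": "areal.launcher.sglang_server", "seed_arg": "sglang.random_seed"},
    "vllm": {"module": "areal.launcher.vllm_server", "seed_arg": "vllm.seed"},
}

gen_backend = "vllm"   # placeholder for alloc_mode.gen_backend
base_seed = 1          # placeholder for config.vllm.seed or config.sglang.random_seed
cli_args = "experiment_name=demo trial_name=run0"  # placeholder for ' '.join(sys.argv[2:])

spec = backend_spec[gen_backend]
server_cmd = f"python3 -m {spec['module']} {cli_args} {spec['seed_arg']}={base_seed}"
print(server_cmd)
# python3 -m areal.launcher.vllm_server experiment_name=demo trial_name=run0 vllm.seed=1

The single job is given gpu = pp_size * tp_size * dp_size devices, and the launcher recovers the per-replica addresses afterwards from name_resolve instead of computing them from locally picked ports.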

areal/launcher/vllm_server.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,10 @@ def launch_server_cmd(command: str, custom_env: dict | None = None) -> subprocess.Popen:
     triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
     unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4()))
     _env["TRITON_CACHE_PATH"] = unique_triton_cache_path
+    # To avoid vllm compile cache conflict
+    vllm_cache_path = _env.get("VLLM_CACHE_ROOT")
+    if vllm_cache_path:
+        _env["VLLM_CACHE_ROOT"] = os.path.join(vllm_cache_path, str(uuid.uuid4()))
     if custom_env is not None:
         _env.update(custom_env)
     return subprocess.Popen(
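
The intent of the added lines: every launched server process gets its own compile-cache directories, so concurrently started vLLM servers do not race on a shared cache. A minimal standalone sketch (the base paths are placeholder defaults; the real code mutates a copy of `os.environ` inside `launch_server_cmd`):

import os
import uuid

env = os.environ.copy()

# Give this launch a unique Triton cache directory under the configured base path.
triton_base = env.get("TRITON_CACHE_PATH", "/tmp/triton")  # placeholder default
env["TRITON_CACHE_PATH"] = os.path.join(triton_base, str(uuid.uuid4()))

# Same for the vLLM compile cache, but only when a base VLLM_CACHE_ROOT is configured.
vllm_base = env.get("VLLM_CACHE_ROOT")
if vllm_base:
    env["VLLM_CACHE_ROOT"] = os.path.join(vllm_base, str(uuid.uuid4()))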

areal/utils/launcher.py

Lines changed: 6 additions & 3 deletions
@@ -15,6 +15,7 @@
 PYTORCH_KERNEL_CACHE_PATH = (
     f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
 )
+VLLM_CACHE_ROOT = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/vllm/"
 TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton/"
 PYTHONPATH = os.pathsep.join(
     filter(

@@ -26,11 +27,13 @@
     )
 )
 os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
+os.makedirs(VLLM_CACHE_ROOT, exist_ok=True)
 os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
 BASE_ENVIRONS = {
     "TOKENIZERS_PARALLELISM": "true",
     "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
     "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
+    "VLLM_CACHE_ROOT": VLLM_CACHE_ROOT,
     "CUDA_DEVICE_MAX_CONNECTIONS": "1",
     "PYTHONPATH": PYTHONPATH,
 }

@@ -48,7 +51,6 @@
     "NCCL_DEBUG": "WARN",
     "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
 }
-LLM_SERVER_WAIT_TIMEOUT_SECONDS = 360


 def get_env_vars(

@@ -103,7 +105,8 @@ class JobInfo:
 def wait_llm_server_addrs(
     experiment_name: str,
     trial_name: str,
-    n_rollout_servers: int,
+    n_rollout_servers: int = 1,
+    timeout: int | None = 360,
 ):
     # Get rollout nodes, find the hosts
     name = names.gen_servers(experiment_name, trial_name)

@@ -117,7 +120,7 @@
             break

         time.sleep(1)
-        if time.perf_counter() - start > LLM_SERVER_WAIT_TIMEOUT_SECONDS:
+        if timeout is not None and time.perf_counter() - start > timeout:
             raise TimeoutError(
                 f"Timeout waiting for rollout servers to be ready. "
                 f"Expected {n_rollout_servers} servers, found {len(rollout_addrs)}."
