
Commit 6138e3a

[Feat] support vllm with slurm launcher (#404)
* [Feat] support vllm with slurm launcher
1 parent: 8f38fe3

File tree

areal/api/cli_args.py
areal/engine/vllm_remote.py
areal/launcher/slurm.py
areal/launcher/vllm_server.py

4 files changed: +113 additions, −52 deletions


areal/api/cli_args.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -455,7 +455,7 @@ class vLLMConfig:
     skip_tokenizer_init: bool = False
     enforce_eager: bool = True
     dtype: str = "bfloat16"
-    distributed_executor_backend = "mp"
+    distributed_executor_backend: str = "mp"
     # original
     max_num_seqs: int = 256
     # kv_cache_type: str = "auto"
@@ -479,6 +479,7 @@ class vLLMConfig:
         "areal.thirdparty.vllm.vllm_worker_extension.VLLMWorkerExtension"
     )
     enable_sleep_mode: bool = False
+    uvicorn_log_level: str = "warning"

     @staticmethod
     def build_args(
```
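Why the added annotation matters (context, not part of the commit): a class attribute without a type annotation is not registered as a dataclass field, so the old `distributed_executor_backend = "mp"` was invisible to dataclass and structured-config tooling. A minimal sketch of the pitfall:

```python
from dataclasses import dataclass, fields

@dataclass
class Demo:
    annotated: str = "mp"  # annotated: becomes a real dataclass field
    unannotated = "mp"     # no annotation: plain class attribute, not a field

# Only the annotated attribute is visible to the dataclass machinery.
print([f.name for f in fields(Demo)])  # ['annotated']
```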

areal/engine/vllm_remote.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -432,7 +432,6 @@ def update_weights_from_distributed(
         ]

         async def _fn():
-            tik = time.perf_counter()
             if init_group:
                 await asyncio.gather(
                     *[
@@ -472,8 +471,6 @@ async def _fn():
                 ]
             )

-            logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
-
         return uvloop.run(_fn())
```
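For reference, the deleted lines were a standard `time.perf_counter()` timing bracket around the gathered update calls; a minimal sketch of that idiom with hypothetical stand-ins:

```python
import asyncio
import time

async def _fn():
    tik = time.perf_counter()      # start the wall-clock stopwatch
    await asyncio.sleep(0.1)       # stand-in for the awaited weight updates
    # elapsed wall-clock time since tik
    print(f"done in {time.perf_counter() - tik:.3f}s")

asyncio.run(_fn())
```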

areal/launcher/slurm.py

Lines changed: 93 additions & 39 deletions
```diff
@@ -16,6 +16,7 @@
     SGLangConfig,
     parse_cli_args,
     to_structured_cfg,
+    vLLMConfig,
 )
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
@@ -431,58 +432,111 @@ def slurm_main(config, run_id: int = 0):
     n_gpus_per_node = config.cluster.n_gpus_per_node
     allocation_mode = config.allocation_mode
     allocation_mode = AllocationMode.from_str(allocation_mode)
-    sglang_cmds = []
-    sglang_addrs = []
-    n_sglang_nodes = 0
-    if allocation_mode.gen_backend == "sglang":
-        # Launcher should launch SGLang servers according to allocation mode.
-        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
-        n_sglang_servers = allocation_mode.gen.dp_size
-        n_sglang_nodes = allocation_mode.gen.world_size // n_gpus_per_node
-        node_group_size = max(1, allocation_mode.gen_instance_size // n_gpus_per_node)
-        n_servers_per_node = max(n_sglang_servers // n_sglang_nodes, 1)
-
-        cross_nodes = allocation_mode.gen_instance_size > n_gpus_per_node
-        env_vars = get_env_vars(
-            config.cluster.cluster_name,
-            config.launcher.inference_server_env_vars,
-        )
-        env_vars = [copy.deepcopy(env_vars) for _ in range(n_sglang_nodes)]
-        base_seed = config.sglang.random_seed
-        sglang_server_cmd_template = f"python3 -m areal.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}"
-        for i in range(n_sglang_nodes):
-            sglang_cmd = sglang_server_cmd_template.format(
-                seed=base_seed + i * n_servers_per_node
+    n_backend_nodes = 0
+
+    if allocation_mode.gen_backend in ("sglang", "vllm"):
+        # Launcher should launch llm servers according to allocation mode.
+        if allocation_mode.gen_backend == "sglang":
+            config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+            random_seed = config.sglang.random_seed
+        else:
+            config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
+            random_seed = config.vllm.seed
+
+        backend_spec = {
+            "sglang": {
+                "module": "areal.launcher.sglang_server",
+                "seed_arg": "sglang.random_seed",
+                "prefix": "AREAL_SGLANG",
+                "set_device_env": False,
+            },
+            "vllm": {
+                "module": "areal.launcher.vllm_server",
+                "seed_arg": "vllm.seed",
+                "prefix": "AREAL_VLLM",
+                "set_device_env": True,  # vLLM needs `device_control_env_var` to control GPU allocation
+            },
+        }
+
+        def _build_llm_server_plan(backend: str, spec: Dict):
+            # Returns: cmds, env_vars_list, n_nodes, n_servers
+
+            if backend not in backend_spec:
+                raise NotImplementedError(f"Unknown backend: {backend}")
+
+            spec = backend_spec[backend]
+
+            n_backend_servers = allocation_mode.gen.dp_size
+            n_backend_nodes = allocation_mode.gen.world_size // n_gpus_per_node
+            node_group_size = max(
+                1, allocation_mode.gen_instance_size // n_gpus_per_node
+            )
+            n_servers_per_node = max(n_backend_servers // n_backend_nodes, 1)
+
+            cross_nodes = allocation_mode.gen_instance_size > n_gpus_per_node
+            base_env_bars = get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.inference_server_env_vars,
+            )
+            if spec["set_device_env"]:
+                base_env_bars[current_platform.device_control_env_var] = ",".join(
+                    list(map(str, range(n_gpus_per_node)))
+                )
+            env_list = [copy.deepcopy(base_env_bars) for _ in range(n_backend_nodes)]
+
+            base_seed = random_seed
+            seed_arg = spec["seed_arg"]
+            module = spec["module"]
+            backend_server_cmd_template = (
+                f"python3 -m {module} {' '.join(sys.argv[2:])} {seed_arg}={{seed}}"
             )
-            sglang_cmds.append(sglang_cmd)
-            if cross_nodes:
-                # master_addrs and master_ports are the IP addresses and free ports of the all nodes in the job array, obtained in the SBATCH script.
-                env_vars[i] |= dict(
-                    AREAL_SGLANG_MULTI_NODE_RANK=i % node_group_size,
-                    AREAL_SGLANG_MULTI_NODE_MASTER_ADDR=f"${{master_addrs[{i // node_group_size * node_group_size}]}}",
-                    AREAL_SGLANG_MULTI_NODE_MASTER_PORT=f"${{master_ports[{i // node_group_size * node_group_size}]}}",
+
+            backend_cmds = []
+            for i in range(n_backend_nodes):
+                backend_cmd = backend_server_cmd_template.format(
+                    seed=base_seed + i * n_servers_per_node
                 )
+                backend_cmds.append(backend_cmd)
+                if cross_nodes:
+                    # master_addrs and master_ports are the IP addresses and free ports of the all nodes in the job array, obtained in the SBATCH script.
+                    prefix = spec["prefix"]
+                    env_list[i] |= dict(
+                        **{
+                            f"{prefix}_MULTI_NODE_RANK": i % node_group_size,
+                            f"{prefix}_MULTI_NODE_MASTER_ADDR": f"${{master_addrs[{i // node_group_size * node_group_size}]}}",
+                            f"{prefix}_MULTI_NODE_MASTER_PORT": f"${{master_ports[{i // node_group_size * node_group_size}]}}",
+                        }
+                    )
+
+            return backend_cmds, env_list, n_backend_nodes, n_backend_servers
+
+        backend_cmds, env_list, n_backend_nodes, n_backend_servers = (
+            _build_llm_server_plan(
+                allocation_mode.gen_backend,
+                random_seed,
+            )
+        )

         launcher.submit_array(
             job_name="llm_server",
-            cmd=sglang_cmds,
-            count=n_sglang_nodes,
-            nodes=n_sglang_nodes,
-            n_gpus_per_node=config.cluster.n_gpus_per_node,
+            cmd=backend_cmds,
+            count=n_backend_nodes,
+            nodes=n_backend_nodes,
+            n_gpus_per_node=n_gpus_per_node,
             cpus_per_task=config.launcher.inference_server_cpus_per_gpu
             * n_gpus_per_node,
             mem_per_task=config.launcher.inference_server_mem_per_gpu * n_gpus_per_node,
             srun_additional_args=config.launcher.slurm.srun_additional_args,
             container_image=config.launcher.slurm.inference_server_image,
             container_mounts=config.launcher.slurm.mount,
-            env_vars=env_vars,
+            env_vars=env_list,
         )
-        # Get SGLang server addresses by name resolve
+        # Get llm server addresses by name resolve
         try:
-            sglang_addrs = wait_llm_server_addrs(
+            llm_addrs = wait_llm_server_addrs(
                 config.experiment_name,
                 config.trial_name,
-                n_sglang_servers,
+                n_backend_servers,
             )
         except (TimeoutError, KeyboardInterrupt) as e:
             launcher.stop_all(force=True)
@@ -492,7 +546,7 @@ def slurm_main(config, run_id: int = 0):
         trainer_n_nodes = 1
         gpus_per_node = 0
     else:
-        trainer_n_nodes = n_nodes - n_sglang_nodes
+        trainer_n_nodes = n_nodes - n_backend_nodes
         gpus_per_node = config.cluster.n_gpus_per_node

     # Here $head_node_ip is the IP address of the first node in the job array.
@@ -534,7 +588,7 @@ def slurm_main(config, run_id: int = 0):
             config.cluster.cluster_name,
             config.launcher.trainer_env_vars,
         ),
-        AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
+        AREAL_LLM_SERVER_ADDRS=",".join(llm_addrs),
         AREAL_RECOVER_RUN=str(int(is_recover_run)),
     ),
 )
```
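For intuition about the multi-node bookkeeping (an illustrative sketch, not repo code): `node_group_size` is how many nodes one inference server spans, `i % node_group_size` is a node's rank within its server group, and `i // node_group_size * node_group_size` rounds down to the group's first node, which serves as the master. Assuming 8 GPUs per node and a 16-GPU server instance:

```python
n_gpus_per_node = 8                 # assumed for the example
gen_instance_size = 16              # one server spans two nodes
node_group_size = max(1, gen_instance_size // n_gpus_per_node)  # -> 2

for i in range(4):                  # four inference nodes
    rank = i % node_group_size                       # rank within the server group
    master = i // node_group_size * node_group_size  # first node of the group
    print(f"node {i}: MULTI_NODE_RANK={rank}, master node index={master}")
# node 0: MULTI_NODE_RANK=0, master node index=0
# node 1: MULTI_NODE_RANK=1, master node index=0
# node 2: MULTI_NODE_RANK=0, master node index=2
# node 3: MULTI_NODE_RANK=1, master node index=2
```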

areal/launcher/vllm_server.py

Lines changed: 18 additions & 9 deletions
```diff
@@ -24,18 +24,21 @@
 logger = logging.getLogger("vLLMServer Wrapper")


-def launch_server_cmd(command: str) -> subprocess.Popen:
+def launch_server_cmd(command: str, custom_env: dict | None = None) -> subprocess.Popen:
     """
     Execute a shell command and return its process handle.
     """
     # Replace newline continuations and split the command string.
     command = command.replace("\\\n", " ").replace("\\", " ")
+    logger.info(f"Launch command: {command}")
     parts = command.split()
     _env = os.environ.copy()
     # To avoid DirectoryNotEmpty error caused by triton
     triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH)
     unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4()))
     _env["TRITON_CACHE_PATH"] = unique_triton_cache_path
+    if custom_env is not None:
+        _env.update(custom_env)
     return subprocess.Popen(
         parts,
         text=True,
@@ -94,13 +97,10 @@ def run(self):
         device_control_env_var = current_platform.device_control_env_var
         if device_control_env_var in os.environ:
             visible = os.getenv(device_control_env_var).split(",")
-            ordered = ",".join(sorted(visible, key=int))
-            os.environ[device_control_env_var] = ordered
             n_visible_devices = len(visible)
             n_servers_per_proc = max(1, n_visible_devices // gpus_per_server)
-            server_idx_offset = int(visible[0]) // gpus_per_server
+            server_idx_offset = min(list(map(int, visible))) // gpus_per_server
         else:
-            n_visible_devices = self.n_gpus_per_node
             n_servers_per_proc = n_servers_per_node
             server_idx_offset = 0

@@ -109,6 +109,7 @@ def run(self):
         ports_per_server = 40000 // n_servers_per_node
         launch_server_args = []
         server_addresses = []
+        base_random_seed = self.config.seed
         for server_local_idx in range(
             server_idx_offset, server_idx_offset + n_servers_per_proc
         ):
@@ -121,15 +122,21 @@ def run(self):
             dist_init_addr = f"localhost:{dist_init_port}"
             host_ip = gethostip()

-            (server_local_idx - server_idx_offset) * gpus_per_server
+            base_gpu_id = (server_local_idx - server_idx_offset) * gpus_per_server
+            custom_env = {
+                device_control_env_var: ",".join(
+                    map(str, range(base_gpu_id, base_gpu_id + gpus_per_server))
+                )
+            }
+            self.config.seed = base_random_seed + server_local_idx
             cmd = vLLMConfig.build_cmd(
                 self.config,
                 tp_size=self.allocation_mode.gen.tp_size,
                 host=host_ip,
                 port=server_port,
                 dist_init_addr=dist_init_addr,
             )
-            launch_server_args.append((cmd, host_ip, server_port))
+            launch_server_args.append((cmd, host_ip, server_port, custom_env))
             server_addresses.append(f"http://{host_ip}:{server_port}")

         with ThreadPoolExecutor(max_workers=n_servers_per_proc) as executor:
@@ -159,8 +166,10 @@ def run(self):

             time.sleep(1)

-    def launch_one_server(self, cmd, host_ip, server_port):
-        server_process = launch_server_cmd(cmd)
+    def launch_one_server(
+        self, cmd: str, host_ip: str, server_port: int, custom_env: dict | None = None
+    ):
+        server_process = launch_server_cmd(cmd, custom_env)
         wait_for_server(f"http://{host_ip}:{server_port}")
         name = names.gen_servers(self.experiment_name, self.trial_name)
         name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
```
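The per-server `custom_env` is how each vLLM server gets pinned to its own GPU slice: the parent sets the platform's `device_control_env_var` (e.g. `CUDA_VISIBLE_DEVICES` on NVIDIA) in the child's environment before spawning it. A minimal standalone sketch of the same pattern (names assumed, not repo code):

```python
import os
import subprocess

gpus_per_server = 2  # assumed for the example
for server_idx in range(2):
    base_gpu_id = server_idx * gpus_per_server
    env = os.environ.copy()
    # Pin this child process to its own contiguous GPU slice.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(g) for g in range(base_gpu_id, base_gpu_id + gpus_per_server)
    )
    subprocess.Popen(
        ["python3", "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"],
        env=env,
    )
# prints "0,1" and "2,3" (one line per child, order may vary)
```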
