     SGLangConfig,
     parse_cli_args,
     to_structured_cfg,
+    vLLMConfig,
 )
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
@@ -431,58 +432,111 @@ def slurm_main(config, run_id: int = 0):
     n_gpus_per_node = config.cluster.n_gpus_per_node
     allocation_mode = config.allocation_mode
     allocation_mode = AllocationMode.from_str(allocation_mode)
-    sglang_cmds = []
-    sglang_addrs = []
-    n_sglang_nodes = 0
-    if allocation_mode.gen_backend == "sglang":
-        # Launcher should launch SGLang servers according to allocation mode.
-        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
-        n_sglang_servers = allocation_mode.gen.dp_size
-        n_sglang_nodes = allocation_mode.gen.world_size // n_gpus_per_node
-        node_group_size = max(1, allocation_mode.gen_instance_size // n_gpus_per_node)
-        n_servers_per_node = max(n_sglang_servers // n_sglang_nodes, 1)
-
-        cross_nodes = allocation_mode.gen_instance_size > n_gpus_per_node
-        env_vars = get_env_vars(
-            config.cluster.cluster_name,
-            config.launcher.inference_server_env_vars,
-        )
-        env_vars = [copy.deepcopy(env_vars) for _ in range(n_sglang_nodes)]
-        base_seed = config.sglang.random_seed
-        sglang_server_cmd_template = f"python3 -m areal.launcher.sglang_server {' '.join(sys.argv[2:])} sglang.random_seed={{seed}}"
-        for i in range(n_sglang_nodes):
-            sglang_cmd = sglang_server_cmd_template.format(
-                seed=base_seed + i * n_servers_per_node
+    n_backend_nodes = 0
+
+    if allocation_mode.gen_backend in ("sglang", "vllm"):
+        # Launcher should launch llm servers according to allocation mode.
+        if allocation_mode.gen_backend == "sglang":
+            config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+            random_seed = config.sglang.random_seed
+        else:
+            config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
+            random_seed = config.vllm.seed
+
+        backend_spec = {
+            "sglang": {
+                "module": "areal.launcher.sglang_server",
+                "seed_arg": "sglang.random_seed",
+                "prefix": "AREAL_SGLANG",
+                "set_device_env": False,
+            },
+            "vllm": {
+                "module": "areal.launcher.vllm_server",
+                "seed_arg": "vllm.seed",
+                "prefix": "AREAL_VLLM",
+                "set_device_env": True,  # vLLM needs `device_control_env_var` to control GPU allocation
+            },
+        }
+
+        def _build_llm_server_plan(backend: str, random_seed: int):
+            # Returns: cmds, env_vars_list, n_nodes, n_servers
+
+            if backend not in backend_spec:
+                raise NotImplementedError(f"Unknown backend: {backend}")
+
+            spec = backend_spec[backend]
+
+            n_backend_servers = allocation_mode.gen.dp_size
+            n_backend_nodes = allocation_mode.gen.world_size // n_gpus_per_node
+            node_group_size = max(
+                1, allocation_mode.gen_instance_size // n_gpus_per_node
+            )
+            n_servers_per_node = max(n_backend_servers // n_backend_nodes, 1)
+
+            cross_nodes = allocation_mode.gen_instance_size > n_gpus_per_node
+            base_env_vars = get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.inference_server_env_vars,
+            )
+            if spec["set_device_env"]:
+                base_env_vars[current_platform.device_control_env_var] = ",".join(
+                    list(map(str, range(n_gpus_per_node)))
+                )
+            env_list = [copy.deepcopy(base_env_vars) for _ in range(n_backend_nodes)]
+
+            base_seed = random_seed
+            seed_arg = spec["seed_arg"]
+            module = spec["module"]
+            backend_server_cmd_template = (
+                f"python3 -m {module} {' '.join(sys.argv[2:])} {seed_arg}={{seed}}"
             )
-            sglang_cmds.append(sglang_cmd)
-            if cross_nodes:
-                # master_addrs and master_ports are the IP addresses and free ports of the all nodes in the job array, obtained in the SBATCH script.
-                env_vars[i] |= dict(
-                    AREAL_SGLANG_MULTI_NODE_RANK=i % node_group_size,
-                    AREAL_SGLANG_MULTI_NODE_MASTER_ADDR=f"${{master_addrs[{i // node_group_size * node_group_size}]}}",
-                    AREAL_SGLANG_MULTI_NODE_MASTER_PORT=f"${{master_ports[{i // node_group_size * node_group_size}]}}",
+
+            backend_cmds = []
+            for i in range(n_backend_nodes):
+                backend_cmd = backend_server_cmd_template.format(
+                    seed=base_seed + i * n_servers_per_node
                 )
+                backend_cmds.append(backend_cmd)
+                if cross_nodes:
+                    # master_addrs and master_ports are the IP addresses and free ports of all the nodes in the job array, obtained in the SBATCH script.
+                    prefix = spec["prefix"]
+                    env_list[i] |= dict(
+                        **{
+                            f"{prefix}_MULTI_NODE_RANK": i % node_group_size,
+                            f"{prefix}_MULTI_NODE_MASTER_ADDR": f"${{master_addrs[{i // node_group_size * node_group_size}]}}",
+                            f"{prefix}_MULTI_NODE_MASTER_PORT": f"${{master_ports[{i // node_group_size * node_group_size}]}}",
+                        }
+                    )
+
+            return backend_cmds, env_list, n_backend_nodes, n_backend_servers
+
+        backend_cmds, env_list, n_backend_nodes, n_backend_servers = (
+            _build_llm_server_plan(
+                allocation_mode.gen_backend,
+                random_seed,
+            )
+        )

         launcher.submit_array(
             job_name="llm_server",
-            cmd=sglang_cmds,
-            count=n_sglang_nodes,
-            nodes=n_sglang_nodes,
-            n_gpus_per_node=config.cluster.n_gpus_per_node,
+            cmd=backend_cmds,
+            count=n_backend_nodes,
+            nodes=n_backend_nodes,
+            n_gpus_per_node=n_gpus_per_node,
             cpus_per_task=config.launcher.inference_server_cpus_per_gpu
             * n_gpus_per_node,
             mem_per_task=config.launcher.inference_server_mem_per_gpu * n_gpus_per_node,
             srun_additional_args=config.launcher.slurm.srun_additional_args,
             container_image=config.launcher.slurm.inference_server_image,
             container_mounts=config.launcher.slurm.mount,
-            env_vars=env_vars,
+            env_vars=env_list,
         )
-        # Get SGLang server addresses by name resolve
+        # Get llm server addresses by name resolve
         try:
-            sglang_addrs = wait_llm_server_addrs(
+            llm_addrs = wait_llm_server_addrs(
                 config.experiment_name,
                 config.trial_name,
-                n_sglang_servers,
+                n_backend_servers,
             )
         except (TimeoutError, KeyboardInterrupt) as e:
             launcher.stop_all(force=True)
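
Reviewer note: the sizing arithmetic inside `_build_llm_server_plan` is easy to check in isolation. The sketch below re-derives the node count, servers per node, and per-node seed offsets with the same formulas as the hunk above; all concrete numbers (16 generation GPUs, 8 GPUs per node, 4-GPU instances, base seed 1) are made-up for illustration, and `dp_size` is approximated as world_size // gen_instance_size here rather than taken from AllocationMode.

# Standalone sanity check of the launcher's sizing math (illustrative values only).
gen_world_size = 16       # assumed: allocation_mode.gen.world_size
gen_instance_size = 4     # assumed: allocation_mode.gen_instance_size (e.g. TP=4)
n_gpus_per_node = 8       # assumed: config.cluster.n_gpus_per_node
base_seed = 1             # assumed: sglang.random_seed / vllm.seed

n_backend_servers = gen_world_size // gen_instance_size            # stands in for gen.dp_size -> 4
n_backend_nodes = gen_world_size // n_gpus_per_node                # -> 2
node_group_size = max(1, gen_instance_size // n_gpus_per_node)     # -> 1 (instance fits on one node)
n_servers_per_node = max(n_backend_servers // n_backend_nodes, 1)  # -> 2
cross_nodes = gen_instance_size > n_gpus_per_node                  # -> False

for i in range(n_backend_nodes):
    # One launcher command per node; its seed is offset by the number of servers
    # hosted on earlier nodes, exactly as in the .format(seed=...) call above.
    print(f"node {i}: seed argument = {base_seed + i * n_servers_per_node}")
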
@@ -492,7 +546,7 @@ def slurm_main(config, run_id: int = 0):
         trainer_n_nodes = 1
         gpus_per_node = 0
     else:
-        trainer_n_nodes = n_nodes - n_sglang_nodes
+        trainer_n_nodes = n_nodes - n_backend_nodes
         gpus_per_node = config.cluster.n_gpus_per_node

     # Here $head_node_ip is the IP address of the first node in the job array.
@@ -534,7 +588,7 @@ def slurm_main(config, run_id: int = 0):
             config.cluster.cluster_name,
             config.launcher.trainer_env_vars,
         ),
-        AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
+        AREAL_LLM_SERVER_ADDRS=",".join(llm_addrs),
         AREAL_RECOVER_RUN=str(int(is_recover_run)),
     ),
 )
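
Reviewer note: with the backend-specific prefix, the cross-node branch now emits AREAL_SGLANG_* or AREAL_VLLM_* variables from a single code path. Below is a minimal sketch of what ends up in `env_list` for a hypothetical vLLM allocation of four 8-GPU nodes with a 16-GPU instance size (so node_group_size = 2); the ${master_addrs[...]} / ${master_ports[...]} placeholders are left for the SBATCH script to expand, as in the original code.

# Illustrative reconstruction of the cross-node env vars (assumed numbers, not repo defaults).
prefix = "AREAL_VLLM"   # spec["prefix"] for the vllm backend
n_backend_nodes = 4     # assumed: 4 generation nodes
node_group_size = 2     # assumed: gen_instance_size=16 GPUs / 8 GPUs per node

env_list = [{} for _ in range(n_backend_nodes)]
for i in range(n_backend_nodes):
    group_leader = i // node_group_size * node_group_size  # first node of this multi-node instance
    env_list[i] |= {
        f"{prefix}_MULTI_NODE_RANK": i % node_group_size,
        f"{prefix}_MULTI_NODE_MASTER_ADDR": f"${{master_addrs[{group_leader}]}}",
        f"{prefix}_MULTI_NODE_MASTER_PORT": f"${{master_ports[{group_leader}]}}",
    }

# Nodes 0 and 1 form one instance (ranks 0/1, master = node 0);
# nodes 2 and 3 form the second (master = node 2).
for i, env in enumerate(env_list):
    print(i, env)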