
Commit 076c3ba

Merge branch 'wht/feature/support_local_scheduler' of github.com:inclusionAI/AReaL into wht/feature/support_local_scheduler

2 parents: 9ae69d2 + 7237882

39 files changed: +652 -230 lines

README.md

Lines changed: 10 additions & 9 deletions

@@ -71,15 +71,16 @@ state-of-the-art 7B and 32B models for mathematical reasoning. Check out our
 
 ## 📚 Examples
 
-| Task | Description | Performance |
-| ---------------------------------------------- | ------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------- |
-| **[Math](examples/math/)** | Mathematical problem solving (SFT, GRPO, or PPO) | TBA |
-| **[LoRA Math](examples/lora/)** | Math Agent Trained With LoRA | TBA |
-| **[VLM Math](examples/vlm/)** | CLEVR visual counting tasks | TBA |
-| **[Reasoning](examples/countdown/)** | Countdown numbers game with custom rewards | [Training Curve](/examples/countdown/countdown_training_curve.png) |
-| **[Search Agent](examples/search-agent/)** | An agent with end-to-end reasoning, search, browsing, and summarization capabilities | [ASearcher Repo](https://github.com/inclusionAI/ASearcher) |
-| **[Tool-Integrated Reasoning](examples/tir/)** | An agent that can invoke tools during reasoning | [TIR Example](https://github.com/inclusionAI/AReaL/tree/main/examples/tir) |
-| **[RLHF](examples/alignment/)** | RLHF for LLM Alignment | [RLHF Example](https://github.com/inclusionAI/AReaL/tree/main/examples/alignment) |
+| Task | Description | Performance |
+| ------------------------------------------------ | ------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------- |
+| **[Math](examples/math/)** | Mathematical problem solving (SFT, GRPO, or PPO) | TBA |
+| **[Multi-Turn Math](examples/multi-turn-math/)** | Iterative mathematical problem solving with self-correction | [Training Curve](examples/multi-turn-math/reward_curve.png) |
+| **[LoRA Math](examples/lora/)** | Math Agent Trained With LoRA | TBA |
+| **[VLM Math](examples/vlm/)** | CLEVR visual counting tasks | TBA |
+| **[Reasoning](examples/countdown/)** | Countdown numbers game with custom rewards | [Training Curve](/examples/countdown/countdown_training_curve.png) |
+| **[Search Agent](examples/search-agent/)** | An agent with end-to-end reasoning, search, browsing, and summarization capabilities | [ASearcher Repo](https://github.com/inclusionAI/ASearcher) |
+| **[Tool-Integrated Reasoning](examples/tir/)** | An agent that can invoke tools during reasoning | [TIR Example](https://github.com/inclusionAI/AReaL/tree/main/examples/tir) |
+| **[RLHF](examples/alignment/)** | RLHF for LLM Alignment | [RLHF Example](https://github.com/inclusionAI/AReaL/tree/main/examples/alignment) |
 
 ## 🔧 Support Matrix

areal/api/cli_args.py

Lines changed: 5 additions & 2 deletions

@@ -10,6 +10,7 @@
 uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
+from hydra.core.global_hydra import GlobalHydra
 from omegaconf import MISSING, DictConfig, OmegaConf
 
 from areal.platforms import current_platform
@@ -295,7 +296,7 @@ class TrainEngineConfig:
     lora_alpha: int = field(default=16, metadata={"help": "lora alpha"})
     target_modules: List[str] = field(
         default_factory=list,
-        metadata={"help": "lora target_modules. None defaults to 'all-linear'"},
+        metadata={"help": "lora target_modules."},
     )
     peft_type: str = field(
         default="lora",
@@ -541,7 +542,7 @@ class SGLangConfig:
     random_seed: int = 1
     skip_tokenizer_init: bool = False
     disable_cuda_graph: bool = False
-    disable_radix_cache: bool = False
+    disable_radix_cache: bool = True
     disable_cuda_graph_padding: bool = False
     enable_nccl_nvls: bool = False
     disable_outlines_disk_cache: bool = False
@@ -1148,6 +1149,8 @@ def parse_cli_args(argv: List[str]):
    assert config_file.exists(), f"Config file {config_file} does not exist."
    # hydra only recognize relative paths
    relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
+    if GlobalHydra.instance().is_initialized():
+        GlobalHydra.instance().clear()
    hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
    cfg = hydra_compose(
        config_name=str(relpath.name).split(".yaml")[0],

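The GlobalHydra guard makes repeated config parsing safe within one process. A minimal sketch of the same pattern in isolation, with an illustrative config directory and name:

# Hydra keeps process-global state, so calling initialize() a second
# time in the same process raises unless the previous instance is
# cleared first. This mirrors the guard added to parse_cli_args above.
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra

def load_cfg(config_dir: str, config_name: str):
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    # hydra only accepts config_path relative to the calling file
    initialize(config_path=config_dir, job_name="app", version_base=None)
    return compose(config_name=config_name)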
areal/engine/base_hf_engine.py

Lines changed: 1 addition & 4 deletions

@@ -117,10 +117,6 @@ def parallelism_group(self) -> dist.ProcessGroup:
 
     def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
         backend = current_platform.communication_backend
-        if current_platform.communication_backend == "nccl":
-            # Required by NCCL weight update group for SGLang
-            os.environ["NCCL_CUMEM_ENABLE"] = "0"
-            os.environ["NCCL_NVLS_ENABLE"] = "0"
         if not dist.is_initialized():
             # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
             # NOTE: device_id **SHOULD NOT** be passed into init_process_group,
@@ -320,6 +316,7 @@ def step_lr_scheduler(self):
 
     def prepare_mb_list(self, input_: Dict[str, Any]) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
+        input_ = input_.copy()
 
         if is_qwen2_vl_model(self.model_config.model_type):
             # Create the special t,h,w position IDs for qwen 2.5 VL

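The added input_.copy() keeps prepare_mb_list from mutating the caller's batch when it attaches derived fields further down the method. A minimal sketch of the idea; the position-ids computation below is illustrative, not AReaL's actual logic:

import torch

def prepare(batch: dict) -> dict:
    # Shallow copy: a new dict, but the same underlying tensors.
    batch = batch.copy()
    # Derived keys added here do not leak back into the caller's dict.
    batch["position_ids"] = torch.cumsum(batch["attention_mask"], dim=-1) - 1
    return batch

original = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
}
prepared = prepare(original)
assert "position_ids" not in original  # the caller's batch is unchanged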
areal/engine/sglang_remote.py

Lines changed: 19 additions & 1 deletion

@@ -26,6 +26,7 @@
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
 from areal.utils.http import arequest_with_retry, get_default_connector
+from areal.utils.launcher import wait_llm_server_addrs
 
 RID_CACHE_SIZE = 128
 
@@ -85,9 +86,26 @@ def initialize(
 
         if addr:
             self.addresses = addr if isinstance(addr, list) else [addr]
+            self.logger.info(f"Get server addresses from the `addr` argument.")
         else:
+            if (
+                self.config.experiment_name is not None
+                and self.config.trial_name is not None
+            ):
+                try:
+                    self.addresses = wait_llm_server_addrs(
+                        experiment_name=self.config.experiment_name,
+                        trial_name=self.config.trial_name,
+                        timeout=1,
+                    )
+                    self.logger.info(f"Get server addresses from name_resolve.")
+                except (TimeoutError, RuntimeError):
+                    # RuntimeError happens when name_resolve is not properly configured.
+                    pass
+            if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
                 # When addr is not provided, fallback to reading addrs from env var
-                self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+                self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
+                self.logger.info(f"Get server addresses from environment variable.")
         if not self.addresses:
             raise RuntimeError(
                 "No configured SGLang servers. Please pass in SGLang server addresses by arguments "

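Both remote engines now resolve server addresses in the same three-step order; the vllm_remote.py hunk below is identical apart from its error message. A condensed sketch of that order, with wait_llm_server_addrs abstracted behind a callable since only its call site appears in this diff:

import os

def resolve_addresses(addr, wait_from_name_resolve) -> list:
    # 1) An explicit `addr` argument always wins.
    if addr:
        return addr if isinstance(addr, list) else [addr]
    addresses = []
    # 2) Otherwise try the name_resolve registry (experiment/trial scoped).
    try:
        addresses = wait_from_name_resolve(timeout=1)
    except (TimeoutError, RuntimeError):
        # Registry empty, or name_resolve not configured at all.
        pass
    # 3) Finally fall back to the AREAL_LLM_SERVER_ADDRS environment variable.
    if not addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
        addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
    if not addresses:
        raise RuntimeError("No configured LLM servers.")
    return addresses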
areal/engine/vllm_remote.py

Lines changed: 19 additions & 1 deletion

@@ -26,6 +26,7 @@
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
 from areal.utils.http import arequest_with_retry, get_default_connector
+from areal.utils.launcher import wait_llm_server_addrs
 
 RID_CACHE_SIZE = 128
 
@@ -90,9 +91,26 @@ def initialize(
 
         if addr:
             self.addresses = addr if isinstance(addr, list) else [addr]
+            self.logger.info(f"Get server addresses from the `addr` argument.")
         else:
+            if (
+                self.config.experiment_name is not None
+                and self.config.trial_name is not None
+            ):
+                try:
+                    self.addresses = wait_llm_server_addrs(
+                        experiment_name=self.config.experiment_name,
+                        trial_name=self.config.trial_name,
+                        timeout=1,
+                    )
+                    self.logger.info(f"Get server addresses from name_resolve.")
+                except (TimeoutError, RuntimeError):
+                    # RuntimeError happens when name_resolve is not properly configured.
+                    pass
+            if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
                 # When addr is not provided, fallback to reading addrs from env var
-                self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+                self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
+                self.logger.info(f"Get server addresses from environment variable.")
         if not self.addresses:
             raise RuntimeError(
                 "No configured vLLM servers. Please pass in vLLM server addresses by arguments "

areal/experimental/tests/test_megatron_engine.py

Lines changed: 2 additions & 2 deletions

@@ -23,9 +23,9 @@
 logger = logging.getLogger("MegatronEngine Test")
 
 VOCAB_SIZE = 100
-MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen3-1.7B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"
 
 
 @pytest.fixture(scope="module")

areal/experimental/tests/test_openai.py

Lines changed: 2 additions & 2 deletions

@@ -15,9 +15,9 @@
 
 EXPR_NAME = "test_openai"
 TRIAL_NAME = "trial_0"
-MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen3-1.7B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"
 PORT, DIST_PORT = network.find_free_ports(2)
 HOST = network.gethostip()
 # set a large timeout since we may need to download the model from hub

areal/experimental/tests/test_sglang_local_engine.py

Lines changed: 2 additions & 2 deletions

@@ -29,9 +29,9 @@
 
 EXPR_NAME = "test_sglang_local_engine"
 TRIAL_NAME = "trial_0"
-MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen2-0.5B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"
 
 
 def build_engine_config(**kwargs):

areal/experimental/tests/torchrun/run_megatron_engine_distributed.py

Lines changed: 2 additions & 2 deletions

@@ -26,11 +26,11 @@
 from areal.utils.data import broadcast_tensor_container
 
 MODEL_PATHS = {
-    "qwen3": "/storage/openpsi/models/Qwen__Qwen3-1.7B/",
+    "qwen3": "/storage/openpsi/models/Qwen__Qwen3-0.6B/",
     "qwen3moe": "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/",
 }
 HF_MODEL_PATHS = {
-    "qwen3": "Qwen/Qwen3-1.7B",
+    "qwen3": "Qwen/Qwen3-0.6B",
     # TODO: switch Qwen3MoE to smaller model initialized from scratch
     "qwen3moe": "Qwen/Qwen3-30B-A3B",
 }

areal/launcher/local.py

Lines changed: 67 additions & 59 deletions

@@ -22,8 +22,14 @@
 )
 from areal.platforms import current_platform
 from areal.utils import logging, name_resolve, names
-from areal.utils.launcher import JobException, JobInfo, JobState, get_env_vars
-from areal.utils.network import find_free_ports, gethostip
+from areal.utils.launcher import (
+    JobException,
+    JobInfo,
+    JobState,
+    get_env_vars,
+    wait_llm_server_addrs,
+)
+from areal.utils.network import find_free_ports
 from areal.utils.recover import check_if_recover
 
 logger = logging.getLogger("Local Scheduler")
@@ -136,7 +142,9 @@ def submit_array(
             )
             c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}"
             logger.info("Starting local process with command: %s", c)
-            process = subprocess.Popen(c, shell=isinstance(c, str))
+            process = subprocess.Popen(
+                c, shell=isinstance(c, str), stdout=sys.stdout, stderr=sys.stdout
+            )
             self._jobs[f"{job_name}/{offset + i}"] = process
             self._job_counter[job_name] += 1

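The Popen change streams each job's output to the launcher's console in addition to the per-job log file. A self-contained sketch of the same pipeline, with an illustrative command and log path:

import subprocess
import sys

# The shell pipeline tees output into the log file; forwarding
# stdout/stderr to sys.stdout lets local runs also stream it live.
cmd = "echo hello 2>&1 | tee -a /tmp/job.log"  # illustrative command
proc = subprocess.Popen(cmd, shell=True, stdout=sys.stdout, stderr=sys.stdout)
proc.wait()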
@@ -275,72 +283,65 @@ def local_main(config, run_id: int = 0):
         f"run_id={run_id}, is_recover_run={is_recover_run}"
     )
 
-    server_cmd = []
     server_addrs = []
-    if alloc_mode.gen_backend == "sglang":
-        base_seed = config.sglang.random_seed
-        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
-        ports = find_free_ports(alloc_mode.gen.dp_size * 2, port_range=(10000, 50000))
-        host_ip = gethostip()
-        host = "localhost" if not config.sglang.enable_metrics else host_ip
-        for i in range(alloc_mode.gen.dp_size):
-            config.sglang.random_seed = base_seed + i
-            cmd = SGLangConfig.build_cmd(
-                config.sglang,
-                host=host,
-                tp_size=alloc_mode.gen.tp_size,
-                base_gpu_id=0,
-                port=ports[i * 2],
-                dist_init_addr=f"localhost:{ports[i*2+1]}",
-            )
-            server_cmd.append(cmd)
-            server_addrs.append(f"{host}:{ports[i * 2]}")
-
-        # Launch inference servers.
-        launcher.submit_array(
-            job_name="llm_server",
-            cmd=server_cmd,
-            count=alloc_mode.gen.dp_size,
-            gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size,
-            env_vars=get_env_vars(
-                config.cluster.cluster_name,
-                config.launcher.inference_server_env_vars,
-            ),
-        )
-        logger.info(
-            f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
+    if alloc_mode.gen_backend in ("sglang", "vllm"):
+        # Launcher should launch llm servers according to allocation mode.
+        if alloc_mode.gen_backend == "sglang":
+            config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+            random_seed = config.sglang.random_seed
+        else:
+            config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
+            random_seed = config.vllm.seed
+
+        backend_spec = {
+            "sglang": {
+                "module": "areal.launcher.sglang_server",
+                "seed_arg": "sglang.random_seed",
+                "set_device_env": False,
+            },
+            "vllm": {
+                "module": "areal.launcher.vllm_server",
+                "seed_arg": "vllm.seed",
+                "set_device_env": True,  # vLLM needs `device_control_env_var` to control GPU allocation
+            },
+        }
+
+        spec = backend_spec[alloc_mode.gen_backend]
+
+        base_seed = random_seed
+        seed_arg = spec["seed_arg"]
+        module = spec["module"]
+        server_cmd = (
+            f"python3 -m {module} {' '.join(sys.argv[1:])} {seed_arg}={base_seed}"
         )
-    elif alloc_mode.gen_backend == "vllm":
-        base_seed = config.vllm.seed
-        config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
-        ports = find_free_ports(alloc_mode.gen.dp_size * 2, port_range=(10000, 50000))
-        host = "localhost"
-        for i in range(alloc_mode.gen.dp_size):
-            config.vllm.seed = base_seed + i
-            cmd = vLLMConfig.build_cmd(
-                config.vllm,
-                host=host,
-                tp_size=alloc_mode.gen.tp_size,
-                port=ports[i * 2],
-                dist_init_addr=f"localhost:{ports[i*2+1]}",
-            )
-            server_cmd.append(cmd)
-            server_addrs.append(f"{host}:{ports[i * 2]}")
 
         # Launch inference servers.
         launcher.submit_array(
             job_name="llm_server",
             cmd=server_cmd,
-            count=alloc_mode.gen.dp_size,
-            gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size,
+            count=1,
+            gpu=alloc_mode.gen.pp_size
+            * alloc_mode.gen.tp_size
+            * alloc_mode.gen.dp_size,
             env_vars=get_env_vars(
                 config.cluster.cluster_name,
                 config.launcher.inference_server_env_vars,
             ),
         )
-        logger.info(
-            f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
-        )
+
+        # Get llm server addresses by name resolve
+        try:
+            server_addrs = wait_llm_server_addrs(
+                config.experiment_name,
+                config.trial_name,
+                n_rollout_servers=alloc_mode.gen.dp_size,
+            )
+            logger.info(
+                f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
+            )
+        except (TimeoutError, KeyboardInterrupt) as e:
+            launcher.stop_all(signal="SIGINT")
+            raise e
 
     # Launch trainer entrypoint
     if alloc_mode.type_ != AllocationType.LLM_SERVER_ONLY:
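In the hunk above, the launcher no longer computes host:port pairs up front; it delegates port selection to the server module and blocks until all dp_size servers have registered. The real wait_llm_server_addrs lives in areal.utils.launcher and is not shown in this diff; a minimal self-contained sketch of the wait-until-n-registered pattern it implements:

import time

# Toy registry standing in for name_resolve; in AReaL the spawned server
# module is expected to publish its host:port under an experiment/trial key.
def wait_for_addrs(registry: dict, key: str, n: int, timeout: float = 300.0):
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        addrs = registry.get(key, [])
        if len(addrs) >= n:  # all dp_size servers have come up
            return addrs[:n]
        time.sleep(1.0)
    raise TimeoutError(f"only {len(registry.get(key, []))}/{n} servers registered")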
@@ -349,6 +350,14 @@ def local_main(config, run_id: int = 0):
             nprocs = 1
         else:
             gpu = nprocs = alloc_mode.train.world_size
+        _env_vars = dict(
+            AREAL_LLM_SERVER_ADDRS=",".join(server_addrs),
+            AREAL_RECOVER_RUN=str(int(is_recover_run)),
+        )
+        if alloc_mode.gen_backend == "sglang":
+            # Required by NCCL weight update group.
+            _env_vars["NCCL_CUMEM_ENABLE"] = "0"
+            _env_vars["NCCL_NVLS_ENABLE"] = "0"
         launcher.submit(
             job_name="trainer",
             cmd=f"torchrun --nnodes 1 --nproc-per-node {nprocs} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
@@ -358,8 +367,7 @@ def local_main(config, run_id: int = 0):
                    config.cluster.cluster_name,
                    config.launcher.trainer_env_vars,
                ),
-                AREAL_LLM_SERVER_ADDRS=",".join(server_addrs),
-                AREAL_RECOVER_RUN=str(int(is_recover_run)),
+                **_env_vars,
             ),
         )

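The NCCL flags that create_process_group used to set at runtime are now injected into the trainer's environment before the process starts, which is when NCCL actually reads them. A small sketch of the resulting composition, grounded in the hunk above:

# Setting the NCCL flags in the child's environment (rather than via
# os.environ inside create_process_group) guarantees they are visible
# before NCCL initializes in the trainer process.
def trainer_env(server_addrs, is_recover_run, gen_backend):
    env = {
        "AREAL_LLM_SERVER_ADDRS": ",".join(server_addrs),
        "AREAL_RECOVER_RUN": str(int(is_recover_run)),
    }
    if gen_backend == "sglang":
        # Required by the NCCL weight update group for SGLang.
        env["NCCL_CUMEM_ENABLE"] = "0"
        env["NCCL_NVLS_ENABLE"] = "0"
    return env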