
Commit 4f420e6

fix unit tests for the next release (#418)
* fix all unit tests
* Use local models and datasets
* Move multi-turn math example outside the `math` folder and update README
* Update README.md

  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* .
* add more tests for examples
* search agent ut pending
* Use vllm/sglang server wrapper for local launching and fix an NCCL issue with vllm
* fix
* .
* .
* run pre-commit
* .
* fix sglang nccl weight update env var
* .
* fix

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3c0bd3d commit 4f420e6

30 files changed: +500 -131 lines

areal/api/cli_args.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
+from hydra.core.global_hydra import GlobalHydra
 from omegaconf import MISSING, DictConfig, OmegaConf

 from areal.platforms import current_platform
@@ -1148,6 +1149,8 @@ def parse_cli_args(argv: List[str]):
     assert config_file.exists(), f"Config file {config_file} does not exist."
     # hydra only recognize relative paths
     relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
+    if GlobalHydra.instance().is_initialized():
+        GlobalHydra.instance().clear()
     hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
     cfg = hydra_compose(
         config_name=str(relpath.name).split(".yaml")[0],
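The `GlobalHydra` guard is what lets `parse_cli_args` be invoked repeatedly in one process (as the unit tests now do): Hydra keeps a global singleton, and a second `hydra_init` call raises unless that singleton is cleared first. A minimal sketch of the pattern, assuming a hypothetical `conf/` directory and `config.yaml`:

```python
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from hydra.core.global_hydra import GlobalHydra


def load_cfg():
    # Clear any previous Hydra state so repeated calls in the same process
    # (e.g. several tests that each parse CLI args) do not fail with
    # "GlobalHydra is already initialized".
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    hydra_init(config_path="conf", job_name="app", version_base=None)
    return hydra_compose(config_name="config")


cfg_a = load_cfg()
cfg_b = load_cfg()  # would raise without the clear() guard above
```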

areal/engine/base_hf_engine.py

Lines changed: 1 addition & 0 deletions
@@ -316,6 +316,7 @@ def step_lr_scheduler(self):

     def prepare_mb_list(self, input_: Dict[str, Any]) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
+        input_ = input_.copy()

         if is_qwen2_vl_model(self.model_config.model_type):
             # Create the special t,h,w position IDs for qwen 2.5 VL
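The added `input_.copy()` is a defensive shallow copy: `prepare_mb_list` adds keys (such as the Qwen2.5-VL position IDs) to the dict, and without the copy those mutations would leak back into the caller's batch. A toy illustration (not the engine code) of the difference:

```python
batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}


def prepare(input_):
    input_ = input_.copy()              # new dict, same underlying values/tensors
    input_["position_ids"] = [0, 1, 2]  # extra key stays local to this call
    return input_


prepared = prepare(batch)
assert "position_ids" in prepared
assert "position_ids" not in batch      # the caller's dict is left untouched
```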

areal/engine/sglang_remote.py

Lines changed: 3 additions & 6 deletions
@@ -34,11 +34,6 @@
 class RemoteSGLangEngine(InferenceEngine):

     def __init__(self, config: InferenceEngineConfig):
-        if current_platform.communication_backend == "nccl":
-            # Required by NCCL weight update group.
-            os.environ["NCCL_CUMEM_ENABLE"] = "0"
-            os.environ["NCCL_NVLS_ENABLE"] = "0"
-
         self.config = config

         self.rid_to_address = {}
@@ -102,11 +97,13 @@ def initialize(
                 timeout=1,
             )
             self.logger.info(f"Get server addresses from name_resolve.")
-        except TimeoutError:
+        except (TimeoutError, RuntimeError):
+            # RuntimeError happens when name_resolve is not properly configured.
             pass
         if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
             # When addr is not provided, fallback to reading addrs from env var
             self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
+            self.logger.info(f"Get server addresses from environment variable.")
         if not self.addresses:
             raise RuntimeError(
                 "No configured SGLang servers. Please pass in SGLang server addresses by arguments "

areal/engine/vllm_remote.py

Lines changed: 3 additions & 1 deletion
@@ -104,11 +104,13 @@ def initialize(
                 timeout=1,
             )
             self.logger.info(f"Get server addresses from name_resolve.")
-        except TimeoutError:
+        except (TimeoutError, RuntimeError):
+            # RuntimeError happens when name_resolve is not properly configured.
             pass
         if not self.addresses and os.getenv("AREAL_LLM_SERVER_ADDRS"):
             # When addr is not provided, fallback to reading addrs from env var
             self.addresses = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")
+            self.logger.info(f"Get server addresses from environment variable.")
         if not self.addresses:
             raise RuntimeError(
                 "No configured vLLM servers. Please pass in vLLM server addresses by arguments "

areal/experimental/tests/test_megatron_engine.py

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@
 logger = logging.getLogger("MegatronEngine Test")

 VOCAB_SIZE = 100
-MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen3-1.7B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"


 @pytest.fixture(scope="module")

areal/experimental/tests/test_openai.py

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@

 EXPR_NAME = "test_openai"
 TRIAL_NAME = "trial_0"
-MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen3-1.7B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"
 PORT, DIST_PORT = network.find_free_ports(2)
 HOST = network.gethostip()
 # set a large timeout since we may need to download the model from hub

areal/experimental/tests/test_sglang_local_engine.py

Lines changed: 2 additions & 2 deletions
@@ -29,9 +29,9 @@

 EXPR_NAME = "test_sglang_local_engine"
 TRIAL_NAME = "trial_0"
-MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
+MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/"
 if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = "Qwen/Qwen2-0.5B"
+    MODEL_PATH = "Qwen/Qwen3-0.6B"


 def build_engine_config(**kwargs):

areal/experimental/tests/torchrun/run_megatron_engine_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -26,11 +26,11 @@
 from areal.utils.data import broadcast_tensor_container

 MODEL_PATHS = {
-    "qwen3": "/storage/openpsi/models/Qwen__Qwen3-1.7B/",
+    "qwen3": "/storage/openpsi/models/Qwen__Qwen3-0.6B/",
     "qwen3moe": "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/",
 }
 HF_MODEL_PATHS = {
-    "qwen3": "Qwen/Qwen3-1.7B",
+    "qwen3": "Qwen/Qwen3-0.6B",
     # TODO: switch Qwen3MoE to smaller model initialized from scratch
     "qwen3moe": "Qwen/Qwen3-30B-A3B",
 }

areal/launcher/local.py

Lines changed: 23 additions & 15 deletions
@@ -283,6 +283,7 @@ def local_main(config, run_id: int = 0):
         f"run_id={run_id}, is_recover_run={is_recover_run}"
     )

+    server_addrs = []
     if alloc_mode.gen_backend in ("sglang", "vllm"):
         # Launcher should launch llm servers according to allocation mode.
         if alloc_mode.gen_backend == "sglang":
@@ -328,19 +329,19 @@
             ),
         )

-    # Get llm server addresses by name resolve
-    try:
-        server_addrs = wait_llm_server_addrs(
-            config.experiment_name,
-            config.trial_name,
-            n_rollout_servers=alloc_mode.gen.dp_size,
-        )
-        logger.info(
-            f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
-        )
-    except (TimeoutError, KeyboardInterrupt) as e:
-        launcher.stop_all(signal="SIGINT")
-        raise e
+        # Get llm server addresses by name resolve
+        try:
+            server_addrs = wait_llm_server_addrs(
+                config.experiment_name,
+                config.trial_name,
+                n_rollout_servers=alloc_mode.gen.dp_size,
+            )
+            logger.info(
+                f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}"
+            )
+        except (TimeoutError, KeyboardInterrupt) as e:
+            launcher.stop_all(signal="SIGINT")
+            raise e

     # Launch trainer entrypoint
     if alloc_mode.type_ != AllocationType.LLM_SERVER_ONLY:
@@ -349,6 +350,14 @@
             nprocs = 1
         else:
            gpu = nprocs = alloc_mode.train.world_size
+        _env_vars = dict(
+            AREAL_LLM_SERVER_ADDRS=",".join(server_addrs),
+            AREAL_RECOVER_RUN=str(int(is_recover_run)),
+        )
+        if alloc_mode.gen_backend == "sglang":
+            # Required by NCCL weight update group.
+            _env_vars["NCCL_CUMEM_ENABLE"] = "0"
+            _env_vars["NCCL_NVLS_ENABLE"] = "0"
         launcher.submit(
             job_name="trainer",
             cmd=f"torchrun --nnodes 1 --nproc-per-node {nprocs} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
@@ -358,8 +367,7 @@
                     config.cluster.cluster_name,
                     config.launcher.trainer_env_vars,
                 ),
-                AREAL_LLM_SERVER_ADDRS=",".join(server_addrs),
-                AREAL_RECOVER_RUN=str(int(is_recover_run)),
+                **_env_vars,
             ),
         )

areal/launcher/ray.py

Lines changed: 9 additions & 2 deletions
@@ -534,6 +534,14 @@ def torch_env_hook(n_tasks: int, placement_group: PlacementGroup) -> List[Dict]:
         )
         return env_vars

+    _env_vars = dict(
+        AREAL_LLM_SERVER_ADDRS=",".join(llm_addrs),
+        AREAL_RECOVER_RUN=str(int(is_recover_run)),
+    )
+    if allocation_mode.gen_backend == "sglang":
+        # Required by NCCL weight update group.
+        _env_vars["NCCL_CUMEM_ENABLE"] = "0"
+        _env_vars["NCCL_NVLS_ENABLE"] = "0"
     launcher.submit_array(
         job_name="trainer",
         file_path=trainer_entry_point,
@@ -549,8 +557,7 @@
                 config.cluster.cluster_name,
                 config.launcher.trainer_env_vars,
             ),
-            AREAL_LLM_SERVER_ADDRS=",".join(llm_addrs),
-            AREAL_RECOVER_RUN=str(int(is_recover_run)),
+            **_env_vars,
         ),
         env_hook=partial(torch_env_hook, trainer_n_nodes * n_gpus_per_node),
     )
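Both launchers now assemble the trainer environment the same way: the SGLang-specific NCCL settings that used to be set inside `RemoteSGLangEngine.__init__` are exported only for the trainer processes, and only when SGLang is the generation backend. A hypothetical, standalone illustration of the resulting environment (addresses and flags made up; variable names mirror the diff):

```python
# Standalone illustration; values are made up, not taken from a real run.
server_addrs = ["127.0.0.1:30000", "127.0.0.1:30001"]
is_recover_run = False
gen_backend = "sglang"

_env_vars = dict(
    AREAL_LLM_SERVER_ADDRS=",".join(server_addrs),
    AREAL_RECOVER_RUN=str(int(is_recover_run)),
)
if gen_backend == "sglang":
    # Required by the NCCL weight update group; applied to the trainer only,
    # instead of being set globally in the inference client as before.
    _env_vars["NCCL_CUMEM_ENABLE"] = "0"
    _env_vars["NCCL_NVLS_ENABLE"] = "0"

print(_env_vars)
# {'AREAL_LLM_SERVER_ADDRS': '127.0.0.1:30000,127.0.0.1:30001',
#  'AREAL_RECOVER_RUN': '0', 'NCCL_CUMEM_ENABLE': '0', 'NCCL_NVLS_ENABLE': '0'}
```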
