Skip to content

Commit 9694495

Browse files
楚财峯回
authored and committed
PullRequest: 970 onpolicy sync training
Merge branch chucai.dzq/update-weights-debug of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/970 Reviewed-by: 峯回 <[email protected]> * onpolicy sync training
1 parent 1503b06 commit 9694495

File tree

14 files changed

+631
-250
lines changed

14 files changed

+631
-250
lines changed

areal/api/io_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class SaveLoadMeta:
206206
processor: Optional["AutoProcessor"] = None
207207
base_model_path: str | None = None
208208
naive_distributed: bool = False
209-
209+
global_step: int | None = None
210210

211211
@dataclass
212212
class RolloutStat:

areal/examples/configs/my001/on_policy.yaml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
experiment_name: mini-model
2-
trial_name: test
3-
allocation_mode: "sglang:d2t4p1+d2t1p4"
1+
experiment_name: mini-model-rebased
2+
trial_name: on-policy
3+
allocation_mode: "sglang:d8t4p1+d8t1p4"
44
seed: 42
5-
total_train_epochs: 1
6-
total_train_steps: 2
5+
total_train_epochs: 10
6+
total_train_steps: 1145
77
weight_update_type: "astate"
88
enable_colocate_mode: false
99
async_training: false
@@ -15,7 +15,7 @@ train_dataset:
1515
path: "/storage/dataset/nlp/areal/moe_lite_math_0527_merge_train_areal.jsonl"
1616
shuffle: true
1717
max_length: 1024
18-
batch_size: 8
18+
batch_size: 64
1919
type: "rl"
2020

2121
scheduler:
@@ -37,13 +37,13 @@ stats_logger:
3737
path: "/home/admin/logs/tfevent/asystem"
3838

3939
gconfig:
40-
n_samples: 1
40+
n_samples: 8
4141
min_new_tokens: 0
4242
# NOTE!!
4343
# Due to the limitations of sglang, max_new_tokens + max_prompt_len must be less than the model's context_len (set in the model's config.json),
4444
# and cannot be equal to it. See https://github.com/sgl-project/sglang/blob/f98366604b23e331422bf3c62d4e7410ae4fab87/python/sglang/srt/managers/tokenizer_manager.py#L638C9-L638C11
45-
max_new_tokens: 256
46-
max_tokens: 1280
45+
max_new_tokens: 15360
46+
max_tokens: 16383
4747
greedy: false
4848
temperature: 1.0
4949
top_k: 1000000
@@ -111,7 +111,7 @@ rollout:
111111
env_vars:
112112
# if use ling max v2, need to specify USE_MAX_V2 = 1
113113
USE_MAX_V2: 1
114-
image: /storage/openpsi/images/hybrid-engine-13680179-20250923154343.sif
114+
image: /storage/openpsi/images/hybrid-engine-13680179-20251015181317.sif
115115

116116
actor: &actor_ref
117117
experiment_name: ${experiment_name}
@@ -177,7 +177,7 @@ actor: &actor_ref
177177
distributed_backend: "nccl"
178178
distributed_timeout_minutes: 600
179179
enable_one_logger: false
180-
expert_model_parallel_size: 1
180+
expert_model_parallel_size: 8
181181
ffn_hidden_size: 5120
182182
first_k_dense_replace: 1
183183
global_batch_size: 512
@@ -293,7 +293,7 @@ actor: &actor_ref
293293
CUDA_LAUNCH_BLOCKING: 1
294294
# if use ling max v2, need to specify USE_MAX_V2 = 1
295295
USE_MAX_V2: 1
296-
image: /storage/openpsi/images/hybrid-engine-13680179-20250923154343.sif
296+
image: /storage/openpsi/images/hybrid-engine-13680179-20251015181317.sif
297297

298298
ref:
299299
<<: *actor_ref
@@ -306,5 +306,5 @@ recover:
306306
latest_disable_save_hf: true
307307
periodic_disable_save_hf: false
308308
latest_save_interval: 1
309-
periodic_save_interval: 20
309+
periodic_save_interval: 2
310310
fileroot: "${storage_prefix}/experiments"

areal/examples/grpo_trainer.py

Lines changed: 223 additions & 67 deletions
Large diffs are not rendered by default.

areal/extension/asystem/ascheduler/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def submit_job(self, job: Job) -> dict[str, Any]:
418418
)
419419

420420
def wait_for_jobs(
421-
self, role: str, submitted_jobs: dict[str, str], timeout: float = 300.0
421+
self, role: str, submitted_jobs: dict[str, str], timeout: float = 1200.0
422422
) -> dict[str, Worker]:
423423
"""
424424
等待作业启动并返回服务器信息
@@ -532,9 +532,6 @@ def _parse_ports_list(self, container_statuses: list) -> list:
532532
return ports_list
533533

534534
def stop_job(self, job_uid: str):
535-
# hack
536-
return
537-
"""停止作业"""
538535
logger.info(f"Stopping job with UID: {job_uid}")
539536

540537
try:

areal/extension/asystem/ascheduler/rpc_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from areal.scheduler.rpc.serialization import deserialize_value, serialize_value
1414
from areal.utils import logging
1515
from areal.utils.http import response_retryable
16-
from areal.extension.asystem.utils.async_utils import run_async_with_loop
1716

1817
logger = logging.getLogger("RPCClient")
1918

areal/extension/asystem/controller/rollout_controller.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import asyncio
8+
import time
89
from concurrent.futures import ThreadPoolExecutor
910

1011
from areal.api.alloc_mode import AllocationMode
@@ -59,7 +60,7 @@ async def _async_initialize(self, job: Job, *args, **kwargs):
5960

6061
# Wait for workers to be ready
6162
self.logger.info("Waiting for workers to be ready...")
62-
self.workers = self.scheduler.get_workers(role=job.role)
63+
self.workers = self.scheduler.get_workers(role=job.role, timeout=1200)
6364
self.logger.info(f"Workers ready: {[w.id for w in self.workers]}")
6465

6566
# Get engine class path for dynamic import on workers
@@ -92,6 +93,8 @@ async def _async_initialize(self, job: Job, *args, **kwargs):
9293
self.scheduler.async_call_engine(worker.id, "initialize", init_config)
9394
for worker, init_config in zip(self.workers, init_configs)
9495
]
96+
import time
97+
time.sleep(60)
9598
await asyncio.gather(*tasks)
9699
self.logger.info("All engines are initialized...")
97100

@@ -179,16 +182,16 @@ def _build_engine_initialize_config(
179182
main_server_addrs = [
180183
f"{worker.ip}:{worker.engine_ports[0]}"
181184
for worker in self.workers[
182-
index : index + self.alloc_mode.gen_instance_size
183-
]
185+
index: index + self.alloc_mode.gen_instance_size
186+
]
184187
if worker.engine_ports
185188
]
186189
free_addrs = [
187190
[
188191
f"{worker.ip}:{port}"
189192
for worker in self.workers[
190-
index : index + self.alloc_mode.gen_instance_size
191-
]
193+
index: index + self.alloc_mode.gen_instance_size
194+
]
192195
for port in worker.engine_ports[1:]
193196
]
194197
]
@@ -207,7 +210,74 @@ def _build_engine_initialize_config(
207210

208211
return init_configs
209212

210-
async def update_weights(self, meta: WeightUpdateMeta) -> None:
213+
def update_weights(self, meta: WeightUpdateMeta) -> None:
214+
"""Update weights - thread-safe for ThreadPoolExecutor calls."""
211215
self.logger.info("begin update_weights")
212-
execute_parallel_tasks(self.workers, self.scheduler, "update_weights", meta)
216+
self._execute_async_task_on_workers("update_weights", meta)
213217
self.logger.info("finish update_weights")
218+
219+
def set_version(self, version: int) -> None:
220+
self._version = version
221+
self.logger.info("begin set_version")
222+
self._execute_async_task_on_workers("set_version", version)
223+
self.logger.info("finish set_version")
224+
225+
def notify_event(self, event: str, global_step: int) -> None:
226+
"""Notify workers about training start/end events.
227+
228+
Args:
229+
event: "train_start" or "train_end"
230+
global_step: Current global step
231+
"""
232+
self.logger.info(f"begin notify_event global_step: {global_step}")
233+
self._execute_async_task_on_workers("notify_event", event, global_step)
234+
self.logger.info(f"finished notify_event global_step: {global_step}")
235+
return None
236+
237+
def _execute_async_task_on_workers(self, method_name: str, *args, **kwargs):
238+
def _run_async_in_thread():
239+
"""Run async code in a thread-safe manner."""
240+
# Always create a new event loop for this thread to avoid conflicts
241+
loop = asyncio.new_event_loop()
242+
asyncio.set_event_loop(loop)
243+
244+
try:
245+
async def _async_exec_func():
246+
try:
247+
self.logger.info(f"Executing {method_name} on {len(self.workers)} workers")
248+
tasks = [
249+
self.scheduler.async_call_engine(
250+
worker.id, method_name, *args, **kwargs, _should_bcast=False
251+
)
252+
for worker in self.workers
253+
]
254+
results = await asyncio.gather(*tasks, return_exceptions=True)
255+
256+
# Check for exceptions in results
257+
for i, result in enumerate(results):
258+
if isinstance(result, Exception):
259+
self.logger.error(
260+
f"Worker {self.workers[i].id} failed to execute {method_name}: {result}")
261+
else:
262+
self.logger.info(f"Worker {self.workers[i].id} successfully executed {method_name}")
263+
264+
# Re-raise if any exceptions occurred
265+
for result in results:
266+
if isinstance(result, Exception):
267+
raise result
268+
269+
return results
270+
except Exception as e:
271+
self.logger.error(f"Failed to execute {method_name} on workers: {e}")
272+
raise e
273+
274+
return loop.run_until_complete(_async_exec_func())
275+
finally:
276+
# Always close the loop we created
277+
if not loop.is_closed():
278+
loop.close()
279+
# Clear the event loop for this thread
280+
asyncio.set_event_loop(None)
281+
282+
return _run_async_in_thread()
283+

0 commit comments

Comments (0)