Skip to content

Commit 1503b06

Browse files
楚财峯回
authored andcommitted
PullRequest: 958 trainer debug
Merge branch chucai.dzq/trainer-debug of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/958 Reviewed-by: 峯回 <[email protected]> * trainer debug
1 parent 8c25b5a commit 1503b06

File tree

11 files changed

+286
-308
lines changed

11 files changed

+286
-308
lines changed

areal/controller/rollout_controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,6 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:
403403
capacity = self.get_capacity()
404404
# Submit pending tasks
405405
self.logger.info(f"Capacity: {capacity}, pending inputs: {len(self._pending_inputs)}")
406-
407406
for _ in range(capacity):
408407
if len(self._pending_inputs) == 0:
409408
break
@@ -484,6 +483,7 @@ def rollout_batch(
484483
A concatenated batch of trajectory results
485484
"""
486485
# Submit all requests
486+
487487
for item in data:
488488
self.submit(
489489
item,

areal/examples/configs/my001/on_policy.yaml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
experiment_name: mini-model
2-
trial_name: on-policy
3-
allocation_mode: "sglang:d8t4p1+d8t1p4"
2+
trial_name: test
3+
allocation_mode: "sglang:d2t4p1+d2t1p4"
44
seed: 42
5-
total_train_epochs: 10
6-
total_train_steps: 1145
5+
total_train_epochs: 1
6+
total_train_steps: 2
77
weight_update_type: "astate"
8-
enable_colocate_mode: true
8+
enable_colocate_mode: false
9+
async_training: false
910

1011
storage_prefix: "/storage/openpsi"
1112

@@ -14,7 +15,7 @@ train_dataset:
1415
path: "/storage/dataset/nlp/areal/moe_lite_math_0527_merge_train_areal.jsonl"
1516
shuffle: true
1617
max_length: 1024
17-
batch_size: 64
18+
batch_size: 8
1819
type: "rl"
1920

2021
scheduler:
@@ -36,13 +37,13 @@ stats_logger:
3637
path: "/home/admin/logs/tfevent/asystem"
3738

3839
gconfig:
39-
n_samples: 8
40+
n_samples: 1
4041
min_new_tokens: 0
4142
# NOTE!!
4243
# Due to the limitations of sglang, max_new_tokens + max_prompt_len must be less than the model's context_len (set in the model's config.json),
4344
# and cannot be equal to it. See https://github.com/sgl-project/sglang/blob/f98366604b23e331422bf3c62d4e7410ae4fab87/python/sglang/srt/managers/tokenizer_manager.py#L638C9-L638C11
44-
max_new_tokens: 15360
45-
max_tokens: 16383
45+
max_new_tokens: 256
46+
max_tokens: 1280
4647
greedy: false
4748
temperature: 1.0
4849
top_k: 1000000
@@ -176,7 +177,7 @@ actor: &actor_ref
176177
distributed_backend: "nccl"
177178
distributed_timeout_minutes: 600
178179
enable_one_logger: false
179-
expert_model_parallel_size: 8
180+
expert_model_parallel_size: 1
180181
ffn_hidden_size: 5120
181182
first_k_dense_replace: 1
182183
global_batch_size: 512

areal/examples/grpo_trainer.py

Lines changed: 40 additions & 232 deletions
Large diffs are not rendered by default.

areal/extension/asystem/ascheduler/rpc_client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from areal.scheduler.rpc.serialization import deserialize_value, serialize_value
1414
from areal.utils import logging
1515
from areal.utils.http import response_retryable
16+
from areal.extension.asystem.utils.async_utils import run_async_with_loop
1617

1718
logger = logging.getLogger("RPCClient")
1819

@@ -172,9 +173,18 @@ def call_engine(self, worker_id, method, max_retries=3, *args, **kwargs):
172173
WorkerFailedError: If worker process has failed
173174
EngineCallError: If method call fails
174175
"""
175-
return asyncio.run(
176-
self.async_call_engine(worker_id, method, max_retries, *args, **kwargs)
177-
)
176+
# 创建新的事件循环并运行异步任务
177+
loop = asyncio.new_event_loop()
178+
asyncio.set_event_loop(loop)
179+
try:
180+
return loop.run_until_complete(
181+
self.async_call_engine(worker_id, method, max_retries, *args, **kwargs)
182+
)
183+
finally:
184+
try:
185+
loop.close()
186+
except Exception:
187+
pass
178188

179189
async def async_call_engine(
180190
self, worker_id, method, max_retries=3, *args, **kwargs
@@ -241,7 +251,7 @@ async def async_call_engine_with_serialized_data(
241251
for attempt in range(1, max_retries + 1):
242252
try:
243253
logger.info(
244-
f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})"
254+
f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt}), url: {url}"
245255
)
246256

247257
response = await self._async_http_client.post(

areal/extension/asystem/controller/rollout_controller.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from concurrent.futures import ThreadPoolExecutor
99

1010
from areal.api.alloc_mode import AllocationMode
11+
from areal.api.io_struct import WeightUpdateMeta
1112
from areal.api.cli_args import InferenceEngineConfig
1213
from areal.api.engine_api import InferenceEngine
1314
from areal.api.scheduler_api import Job, Scheduler
@@ -20,6 +21,7 @@
2021
from areal.extension.asystem.remote_hybrid_inference_worker import (
2122
RemoteHypidInferenceInitConfig,
2223
)
24+
from areal.extension.asystem.controller.util import execute_parallel_tasks
2325
from areal.utils import logging
2426

2527

@@ -204,3 +206,8 @@ def _build_engine_initialize_config(
204206
init_configs.append(init_config)
205207

206208
return init_configs
209+
210+
async def update_weights(self, meta: WeightUpdateMeta) -> None:
211+
self.logger.info("begin update_weights")
212+
execute_parallel_tasks(self.workers, self.scheduler, "update_weights", meta)
213+
self.logger.info("finish update_weights")

areal/extension/asystem/controller/train_controller.py

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
from typing import Any
1313
from areal.extension.asystem.api.cli_args import TrainEngineConfig
1414
from areal.api.engine_api import TrainEngine
15-
from areal.api.io_struct import AllocationMode, FinetuneSpec
15+
from areal.api.io_struct import FinetuneSpec
1616
from areal.api.scheduler_api import Job, Scheduler
1717
from areal.controller.train_controller import TrainController as BaseTrainController
18+
from areal.extension.asystem.controller.util import execute_parallel_tasks, calc_metrics
1819
from areal.extension.asystem.remote_hybrid_train_worker import RemoteMegatronInitConfig
1920
from areal.utils import logging, stats_tracker
2021
from areal.controller.batch import DistributedBatch
@@ -23,55 +24,6 @@
2324
logger = logging.getLogger("TrainController")
2425

2526

26-
def _execute_parallel_tasks(workers, scheduler, method_name, *args):
27-
"""Execute tasks in parallel across all workers.
28-
29-
This is a helper function to reduce code duplication when executing
30-
the same method on all workers with identical parameters.
31-
32-
Parameters
33-
----------
34-
workers : list
35-
List of worker objects
36-
scheduler : Scheduler
37-
Scheduler instance for async calls
38-
method_name : str
39-
Name of the method to call on each worker's engine
40-
*args, **kwargs
41-
Arguments to pass to the method
42-
43-
Returns
44-
-------
45-
list
46-
Results from all workers
47-
48-
Raises
49-
------
50-
RuntimeError
51-
If any worker fails to execute the task
52-
"""
53-
tasks = [
54-
scheduler.async_call_engine(
55-
worker.id, method_name, *args, _should_bcast=False
56-
)
57-
for worker in workers
58-
]
59-
60-
try:
61-
return asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
62-
except KeyboardInterrupt:
63-
raise
64-
except Exception as e:
65-
raise RuntimeError(f"{method_name} failed, error: {e}")
66-
67-
68-
def _calc_metrics(batch_inputs):
69-
# seqlen std
70-
seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
71-
seqlen_std = torch.tensor(seqlens).float().std().item()
72-
stats_tracker.scalar(**{"seqlen_std": seqlen_std})
73-
74-
7527
class TrainController(BaseTrainController):
7628
"""ASystem-specific TrainController.
7729
@@ -218,7 +170,7 @@ def train_batch(
218170
with (stats_tracker.record_timing("train_batch_data_split"), ):
219171
batches = input_.chunk_by_ffd(self.group_size, self.dp_size)
220172

221-
_calc_metrics(batches)
173+
calc_metrics(batches)
222174

223175
tasks = [
224176
self.scheduler.async_call_engine(
@@ -286,15 +238,17 @@ def compute_logp(self, input_: DistributedBatch) -> Tensor:
286238

287239
def upload_weights(self, meta: WeightUpdateMeta):
288240
"""Upload weights to the inference engine."""
289-
_execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
241+
self.logger.info("begin upload_weights")
242+
execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
243+
self.logger.info("finished upload_weights")
290244

291245
def save(self, meta: SaveLoadMeta):
292246
"""Save model weights (and optimizer states) for later use."""
293-
_execute_parallel_tasks(self.workers, self.scheduler, "save", meta)
247+
execute_parallel_tasks(self.workers, self.scheduler, "save", meta)
294248

295249
def load(self, meta: SaveLoadMeta):
296250
"""Load model weights and optimizer states from a file."""
297-
_execute_parallel_tasks(self.workers, self.scheduler, "load", meta)
251+
execute_parallel_tasks(self.workers, self.scheduler, "load", meta)
298252

299253
def notify_event(self, event: str, global_step: int) -> None:
300254
"""Notify workers about training start/end events.
@@ -303,5 +257,5 @@ def notify_event(self, event: str, global_step: int) -> None:
303257
event: "train_start" or "train_end"
304258
global_step: Current global step
305259
"""
306-
_execute_parallel_tasks(self.workers, self.scheduler, "notify_event", event, global_step)
260+
execute_parallel_tasks(self.workers, self.scheduler, "notify_event", event, global_step)
307261
return None
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
from concurrent.futures import ThreadPoolExecutor
3+
4+
import torch
5+
6+
from areal.utils import stats_tracker
7+
from areal.extension.asystem.utils.async_utils import run_async_with_loop
8+
9+
10+
def execute_parallel_tasks(workers, scheduler, method_name, *args):
11+
"""Execute tasks in parallel across all workers.
12+
13+
This is a helper function to reduce code duplication when executing
14+
the same method on all workers with identical parameters.
15+
16+
Parameters
17+
----------
18+
workers : list
19+
List of worker objects
20+
scheduler : Scheduler
21+
Scheduler instance for async calls
22+
method_name : str
23+
Name of the method to call on each worker's engine
24+
*args, **kwargs
25+
Arguments to pass to the method
26+
27+
Returns
28+
-------
29+
list
30+
Results from all workers
31+
32+
Raises
33+
------
34+
RuntimeError
35+
If any worker fails to execute the task
36+
"""
37+
logger.info(f"[DEBUG] execute_parallel_tasks called with method: {method_name}, workers: {[w.id for w in workers]}")
38+
tasks = [
39+
scheduler.async_call_engine(
40+
worker.id, method_name, *args, _should_bcast=False
41+
)
42+
for worker in workers
43+
]
44+
45+
try:
46+
logger.info(f"[DEBUG] Created {len(tasks)} async tasks")
47+
# 创建新的事件循环并运行所有任务
48+
loop = asyncio.new_event_loop()
49+
asyncio.set_event_loop(loop)
50+
try:
51+
logger.info(f"[DEBUG] Starting async execution")
52+
result = loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=False))
53+
logger.info(f"[DEBUG] Async execution completed successfully")
54+
return result
55+
finally:
56+
try:
57+
loop.close()
58+
except Exception:
59+
pass
60+
except KeyboardInterrupt:
61+
raise
62+
except Exception as e:
63+
logger.error(f"[DEBUG] execute_parallel_tasks failed: {str(e)}")
64+
raise RuntimeError(f"{method_name} failed, error: {e}")
65+
66+
67+
def calc_metrics(batch_inputs):
68+
# seqlen std
69+
seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
70+
seqlen_std = torch.tensor(seqlens).float().std().item()
71+
stats_tracker.scalar(**{"seqlen_std": seqlen_std})

areal/extension/asystem/remote_hybrid_inference_worker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ async def _rollout_thread_async(self):
232232
):
233233
data, workflow = self.input_queue.get_nowait()
234234

235-
# logger.info(f"Get data from puller: {data}")
235+
logger.info(f"_rollout_thread_async before arun_episode data: {data}")
236236
task = asyncio.create_task(
237237
(
238238
workflow.arun_episode(self, data)
@@ -427,6 +427,8 @@ def get_capacity(self):
427427
return capacity
428428

429429
def update_weights(self, meta):
430+
logger.info(f"[DEBUG] update_weights called with meta: {meta}")
431+
logger.info(f"[DEBUG] Available addresses: {self.addresses}")
430432
self._update_weights(meta)
431433
return True
432434

@@ -460,7 +462,7 @@ def update_single_server(addr):
460462
wait_future_ordered(futures)
461463

462464
logger.info(
463-
f"Loading weights done in {(time.time_ns() - load_timestamp) / 1e6:.2f} ms, updated version: {meta.model_version}"
465+
f"Loading weights done in {(time.time_ns() - load_timestamp) / 1e6:.2f} ms"
464466
)
465467
elif meta.type == "nccl" or meta.type == "astate":
466468
load_timestamp = time.time_ns()
@@ -498,15 +500,15 @@ def update_single_server(addr):
498500
wait_future_ordered(futures)
499501

500502
logger.info(
501-
f"Loading weights done in {(time.time_ns() - load_timestamp) / 1e6:.2f} ms, updated version: {meta.model_version}"
503+
f"Loading weights done in {(time.time_ns() - load_timestamp) / 1e6:.2f} ms"
502504
)
503505
else:
504506
raise FrameworkError(
505507
"FrameworkError",
506508
"InferenceWorkerError",
507509
f"Unknown weight update type {meta.type}",
508510
)
509-
self.set_version(meta.model_version)
511+
# self.set_version(meta.model_version)
510512

511513
def pause(self):
512514
self.paused.set()

areal/extension/asystem/remote_hybrid_train_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self, config: TrainEngineConfig):
4646
self.megatron_addr = None
4747
self.global_step = self.config.global_step
4848
self.global_rank = 0
49+
self._version: int = 0
4950

5051
# initialization
5152
self.initialized = False
@@ -930,6 +931,11 @@ def _compute_logprobs(
930931

931932
return None
932933

934+
def set_version(self, version: int):
935+
self._version = version
936+
937+
def get_version(self) -> int:
938+
return self._version
933939

934940
def serialize_and_compress(data):
935941
serialized_data = cloudpickle.dumps(data)

0 commit comments

Comments
 (0)