
Commit 72c8da4

楚财峯回 authored and committed
PullRequest: 944 Train controller changes to adapt to open source and the controller
Merge branch chucai.dzq/train-controller-adapt-opensource of [email protected]:inclusionAI/AReaL.git into asystem/gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/944
Reviewed-by: 峯回 <[email protected]>
* train controller adapt controller
* train controller adapt opensource
1 parent e02f4a5 commit 72c8da4

File tree

3 files changed: +179 -11 lines changed


areal/examples/configs/my001/on_policy.yaml

Lines changed: 2 additions & 3 deletions

@@ -18,8 +18,9 @@ train_dataset:
   type: "rl"
 
 scheduler:
-  endpoint: "http://asystem-scheduler.asystem-my001-swift.svc.sigma-my001.ml01.sgp-ml.local:8081"
+  # endpoint: "http://asystem-scheduler.asystem-my001-swift.svc.sigma-my001.ml01.sgp-ml.local:8081"
   functioncall_service_domain: "http://110.75.237.19:8080"
+  endpoint: "http://asystem-scheduler.asystem-cluster-prod-1.svc:8081"
   reward_model_path: "/storage/jiulin.jl/Skywork-Reward-V2-Qwen3-8B"
   reward_model_service_url: "http://reward-model-service.asystem-test.svc.sigma-my001.ml01.sgp-ml.local:30000/classify"
 
@@ -111,8 +112,6 @@ actor: &actor_ref
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
   hybrid_engine:
-    experiment_name: ${experiment_name}
-    trial_name: ${trial_name}
     group_size: ${gconfig.n_samples}
     train_bs_n_seqs: ${train_dataset.batch_size}
     max_tokens_per_mb: 16384

areal/examples/grpo_trainer.py

Lines changed: 3 additions & 0 deletions

@@ -262,6 +262,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
             role="actor",
             alloc_mode=allocation_mode,
             ft_spec=ft_spec,
+            group_size=config.gconfig.n_samples,
+            enable_colocate_mode=config.enable_colocate_mode,
+            storage_prefix=config.storage_prefix,
         ),
         executor.submit(
             rollout.initialize, role="rollout", alloc_mode=allocation_mode
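For readers outside the ASystem setup, here is a minimal sketch of how the surrounding helper might look once these keyword arguments are threaded through. The ThreadPoolExecutor wrapper and the exact helper signature are assumptions for illustration; only the executor.submit calls and the argument names come from this diff.

from concurrent.futures import ThreadPoolExecutor

def init_train_and_rollout_controller_helper(actor, rollout, config, allocation_mode, ft_spec):
    # Sketch: initialize the train and rollout controllers concurrently,
    # then block until both have finished (re-raising any failure).
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(
                actor.initialize,
                role="actor",
                alloc_mode=allocation_mode,
                ft_spec=ft_spec,
                # New keyword arguments added by this commit:
                group_size=config.gconfig.n_samples,
                enable_colocate_mode=config.enable_colocate_mode,
                storage_prefix=config.storage_prefix,
            ),
            executor.submit(
                rollout.initialize, role="rollout", alloc_mode=allocation_mode
            ),
        ]
        for future in futures:
            future.result()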

areal/extension/asystem/controller/train_controller.py

Lines changed: 174 additions & 8 deletions

@@ -5,18 +5,73 @@
 """
 
 import asyncio
+import torch
 
-from areal.api.cli_args import TrainEngineConfig
+from torch import Tensor
+from collections.abc import Callable
+from typing import Any
+from areal.extension.asystem.api.cli_args import TrainEngineConfig
 from areal.api.engine_api import TrainEngine
 from areal.api.io_struct import AllocationMode, FinetuneSpec
 from areal.api.scheduler_api import Job, Scheduler
 from areal.controller.train_controller import TrainController as BaseTrainController
 from areal.extension.asystem.remote_hybrid_train_worker import RemoteMegatronInitConfig
-from areal.utils import logging
+from areal.utils import logging, stats_tracker
+from areal.controller.batch import DistributedBatch
+from areal.api.io_struct import AllocationMode, SaveLoadMeta, WeightUpdateMeta
 
 logger = logging.getLogger("TrainController")
 
 
+def _execute_parallel_tasks(workers, scheduler, method_name, *args):
+    """Execute tasks in parallel across all workers.
+
+    This is a helper function to reduce code duplication when executing
+    the same method on all workers with identical parameters.
+
+    Parameters
+    ----------
+    workers : list
+        List of worker objects
+    scheduler : Scheduler
+        Scheduler instance for async calls
+    method_name : str
+        Name of the method to call on each worker's engine
+    *args, **kwargs
+        Arguments to pass to the method
+
+    Returns
+    -------
+    list
+        Results from all workers
+
+    Raises
+    ------
+    RuntimeError
+        If any worker fails to execute the task
+    """
+    tasks = [
+        scheduler.async_call_engine(
+            worker.id, method_name, *args, _should_bcast=False
+        )
+        for worker in workers
+    ]
+
+    try:
+        return asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+    except KeyboardInterrupt:
+        raise
+    except Exception as e:
+        raise RuntimeError(f"{method_name} failed, error: {e}")
+
+
+def _calc_metrics(batch_inputs):
+    # seqlen std
+    seqlens = [td["seqlen"].sum().item() for td in batch_inputs]
+    seqlen_std = torch.tensor(seqlens).float().std().item()
+    stats_tracker.scalar(**{"seqlen_std": seqlen_std})
+
+
 class TrainController(BaseTrainController):
     """ASystem-specific TrainController.
 
@@ -69,9 +124,16 @@ def initialize(
         self.logger = logging.getLogger("[TrainController]")
 
         # Store configuration
+        self.parallel_strategy = alloc_mode.train
         self._worker_role = role
         self.alloc_mode = alloc_mode
-        self.parallel_strategy = alloc_mode.train
+        self.world_size = self.alloc_mode.train.world_size
+        self.dp_size = self.alloc_mode.train.dp_size
+        self.tp_size = self.alloc_mode.train.tp_size
+        self.pp_size = self.alloc_mode.train.pp_size
+        self.group_size = kwargs.get("group_size")
+        self.enable_colocate_mode = kwargs.get("enable_colocate_mode")
+        self.storage_prefix = kwargs.get("storage_prefix")
 
         # Create job for scheduler
         job = Job(
@@ -99,10 +161,6 @@ def initialize(
         asyncio.run(self._async_create_engines(engine_path))
         asyncio.run(self._async_initialize(job, ft_spec, **kwargs))
 
-        # Identify DP head workers
-        # todo: @chucai, implement this, record rank info in hybrid train worker and implement is_data_parallel_head...
-        # self._identify_dp_heads()
-
         self.logger.info("TrainController initialization complete")
 
     async def _async_initialize(self, job: Job, ft_spec: FinetuneSpec, **kwargs):
@@ -121,7 +179,17 @@ async def _async_initialize(self, job: Job, ft_spec: FinetuneSpec, **kwargs):
             for worker, init_config in zip(self.workers, init_configs)
         ]
 
-        await asyncio.gather(*tasks)
+        self.rank_info = {}
+        try:
+            gather_results = await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            self.logger.error(f"Initialization failed with error: {e}")
+            raise RuntimeError(f"Failed to initialize workers, error: {e}")
+
+        for worker_index, result in enumerate(gather_results):
+            self.rank_info[worker_index] = result
+            self.logger.info(f"Worker {worker_index} succeeded: {result}")
+
         self.logger.info("All engines are initialized!")
 
     def _build_engine_initialize_config(
@@ -139,3 +207,101 @@ def _build_engine_initialize_config(
             )
             for index, worker in enumerate(self.workers)
         ]
+
+    def train_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
+    ) -> dict[str, float]:
+        self.logger.info(f"start to train_batch")
+        with (stats_tracker.record_timing("train_batch_data_split"), ):
+            batches = input_.chunk_by_ffd(self.group_size, self.dp_size)
+
+        _calc_metrics(batches)
+
+        tasks = [
+            self.scheduler.async_call_engine(
+                worker.id, "train_batch", batches[self.rank_info[index]["dp_rank"]], _should_bcast=False
+            )
+            for index, worker in enumerate(self.workers)
+        ]
+
+        try:
+            results = asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+        except KeyboardInterrupt:
+            raise
+        except Exception as e:
+            raise RuntimeError(f"train_batch failed, error: {e}")
+
+        for worker_result in results:
+            if len(worker_result) > 1:
+                for minibatch in worker_result:
+                    stats_tracker.scalar(**minibatch)
+            else:
+                stats_tracker.scalar(**worker_result[0])
+
+        return {}
+
+    def compute_logp(self, input_: DistributedBatch) -> Tensor:
+        """Update the model with a batch of data and a loss function."""
+        logger.info(f"start to compute_logp")
+        with (
+            stats_tracker.record_timing("compute_logp_data_split"),
+        ):
+            batches = input_.chunk(self.dp_size)
+        tasks = [
+            self.scheduler.async_call_engine(
+                worker.id, "compute_logprobs", batches[self.rank_info[index]["dp_rank"]], _should_bcast=False
+            )
+            for index, worker in enumerate(self.workers)
+        ]
+
+        try:
+            results = asyncio.run(asyncio.gather(*tasks, return_exceptions=False))
+        except KeyboardInterrupt:
+            raise
+        except Exception as e:
+            raise RuntimeError(f"compute_logp failed, error: {e}")
+
+        # cat tensor from dp head with padding
+        tensors_from_dp_heads = results[: self.dp_size]
+        if not tensors_from_dp_heads:
+            return torch.tensor([])
+
+        # Find max length in dim 1
+        max_len = max(t.shape[1] for t in tensors_from_dp_heads)
+        max_len_all = max(t.shape[1] for t in results)
+        assert max_len_all == max_len
+        # Pad all tensors to max length
+        padded_tensors = []
+        for t in tensors_from_dp_heads:
+            pad_size = max_len - t.shape[1]
+            padded = torch.nn.functional.pad(t, (0, pad_size), value=0.0)
+            padded_tensors.append(padded)
+
+        # Concatenate along batch dimension
+        concatenated_result = torch.cat(padded_tensors, dim=0)
+        return concatenated_result
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        """Upload weights to the inference engine."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "upload_weights", meta)
+
+    def save(self, meta: SaveLoadMeta):
+        """Save model weights (and optimizer states) for later use."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "save", meta)
+
+    def load(self, meta: SaveLoadMeta):
+        """Load model weights and optimizer states from a file."""
+        _execute_parallel_tasks(self.workers, self.scheduler, "load", meta)
+
+    def notify_event(self, event: str, global_step: int) -> None:
+        """Notify workers about training start/end events.
+
+        Args:
+            event: "train_start" or "train_end"
+            global_step: Current global step
+        """
+        _execute_parallel_tasks(self.workers, self.scheduler, "notify_event", event, global_step)
+        return None
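The padding step at the end of compute_logp exists because each data-parallel head can return log-probabilities with a different sequence length, so the per-head tensors are right-padded to the longest length before being concatenated along the batch dimension. Below is a standalone illustration of just that step, with made-up shapes and plain torch calls; none of the controller or scheduler machinery from this commit is involved.

import torch

# Hypothetical outputs from three DP heads, each [batch, seqlen] with differing seqlens.
results = [torch.randn(4, 10), torch.randn(4, 7), torch.randn(4, 9)]

max_len = max(t.shape[1] for t in results)
padded = [
    # Right-pad the sequence dimension with zeros up to max_len.
    torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=0.0)
    for t in results
]
logps = torch.cat(padded, dim=0)
print(logps.shape)  # torch.Size([12, 10])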
