Skip to content

Commit 6876b63

Browse files
楚财峯回
authored and committed
PullRequest: 994 修复奖励数据处理和对齐工具,新增日志和指标计算功能
Merge branch chucai.dzq/align-onpolicy of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/994 Reviewed-by: 峯回 <[email protected]> * align on-policy, modify dataload and rlvr
1 parent 9694495 commit 6876b63

File tree

13 files changed

+450
-228
lines changed

13 files changed

+450
-228
lines changed

areal/api/io_struct.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,41 @@
1717
from transformers import AutoProcessor
1818

1919

20+
@dataclass
class LLMRequest:
    """A single text-generation request handed to an LLM inference engine.

    Exactly one of ``text`` / ``input_ids`` is expected to carry the prompt;
    which one is used is up to the engine — TODO confirm with callers.
    """

    # Unique request id; a fresh UUID string unless the caller supplies one.
    rid: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Prompt as raw text, or None when the prompt is given pre-tokenized.
    text: str | None = None
    # Prompt as pre-tokenized token ids.
    input_ids: list[int] = field(default_factory=list)
    # Per-request sampling/generation settings.
    gconfig: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters
    )
    # Free-form caller metadata carried alongside the request.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Target model identifier; None presumably selects the engine default —
    # NOTE(review): confirm against the serving layer.
    model_id: str | None = None
30+
31+
32+
@dataclass
class LLMResponse:
    """Result of one LLM generation call: tokens produced plus timing stats."""

    # outputs
    input_tokens: list[int] = field(default_factory=list)
    output_tokens: list[int] = field(default_factory=list)
    output_logprobs: list[float] = field(default_factory=list)
    # Version of the model weights that produced this output.
    # (Plain `= 0` replaces `field(default_factory=int)` — int is immutable,
    # so a factory is an unnecessary indirection for the same default.)
    output_version: int = 0
    stop_reason: Literal["length", "stop", "interrupt", "abort"] = "stop"

    # statistics — default to +inf so an unmeasured response sorts as "slowest".
    latency: float = float("inf")
    ttft: float = float("inf")  # Time to first token
    itl: list[float] = field(default_factory=list)  # List of inter-token latencies

    @property
    def input_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.input_tokens)

    @property
    def output_len(self) -> int:
        """Number of generated tokens."""
        return len(self.output_tokens)
53+
54+
2055
@dataclass
2156
class ModelRequest:
2257
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -208,6 +243,7 @@ class SaveLoadMeta:
208243
naive_distributed: bool = False
209244
global_step: int | None = None
210245

246+
211247
@dataclass
212248
class RolloutStat:
213249
submitted: int = 0

areal/examples/configs/my001/on_policy.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ train_dataset:
1919
type: "rl"
2020

2121
scheduler:
22-
# endpoint: "http://asystem-scheduler.asystem-my001-swift.svc.sigma-my001.ml01.sgp-ml.local:8081"
23-
functioncall_service_domain: "http://110.75.237.19:8080"
22+
# endpoint: "http://asystem-scheduler.asystem-my001-swift.svc.sigma-my001.ml01.sgp-ml.local:8081"
2423
endpoint: "http://asystem-scheduler.asystem-cluster-prod-1.svc:8081"
24+
functioncall_service_domain: "http://110.75.237.19:8080"
2525
reward_model_path: "/storage/jiulin.jl/Skywork-Reward-V2-Qwen3-8B"
2626
reward_model_service_url: "http://reward-model-service.asystem-test.svc.sigma-my001.ml01.sgp-ml.local:30000/classify"
2727

@@ -43,7 +43,6 @@ gconfig:
4343
# Due to the limitations of sglang, max_new_tokens + max_prompt_len must be less than the model's context_len (set in the model's config.json),
4444
# and cannot be equal to it. See https://github.com/sgl-project/sglang/blob/f98366604b23e331422bf3c62d4e7410ae4fab87/python/sglang/srt/managers/tokenizer_manager.py#L638C9-L638C11
4545
max_new_tokens: 15360
46-
max_tokens: 16383
4746
greedy: false
4847
temperature: 1.0
4948
top_k: 1000000

areal/examples/grpo_trainer.py

Lines changed: 60 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
import concurrent.futures
12
import json
23
import os
34
import pprint
4-
import sys
5-
import asyncio
65
import shutil
7-
import concurrent.futures
6+
import sys
87
from concurrent.futures import ThreadPoolExecutor
98

109
from datasets import load_dataset
1110
from torchdata.stateful_dataloader import StatefulDataLoader
1211

12+
from realhf.api.core.data_api import load_hf_tokenizer
13+
1314
from areal.api.cli_args import (
1415
SchedulingStrategy,
1516
load_expr_config,
1617
)
18+
from areal.api.engine_api import WeightUpdateMeta
1719
from areal.api.io_struct import AllocationMode, FinetuneSpec
1820
from areal.extension.asystem.api.cli_args import GRPOConfig
1921
from areal.extension.asystem.ascheduler import AsystemScheduler
@@ -23,23 +25,14 @@
2325
RemoteHybridInferenceWorker,
2426
)
2527
from areal.extension.asystem.remote_hybrid_train_worker import RemoteHybridTrainWorker
26-
from areal.extension.asystem.util import ShuffleSampler, wait_future_ordered
28+
from areal.extension.asystem.utils.align_tools import summarize_rewards
29+
from areal.extension.asystem.utils.util import ShuffleSampler, wait_future_ordered
2730
from areal.utils import logging, stats_tracker
28-
from areal.utils.hf_utils import load_hf_tokenizer
2931
from areal.utils.stats_logger import StatsLogger
30-
from areal.api.engine_api import WeightUpdateMeta
3132

3233
logger = logging.getLogger("Trainer")
3334

3435

35-
def custom_collate_fn(batch):
    """Collate a list of sample dicts into one dict of per-key lists.

    Keys are the union over all samples; a sample missing a key contributes
    None at its position, so every output list has len(batch) entries.
    """
    keys = set().union(*(sample.keys() for sample in batch))
    return {key: [sample.get(key) for sample in batch] for key in keys}
41-
42-
4336
def clear_dir(path):
4437
if os.path.exists(path):
4538
for filename in os.listdir(path):
@@ -54,13 +47,9 @@ def main(args):
5447
config, _ = load_expr_config(args, GRPOConfig)
5548
config: GRPOConfig
5649

57-
if config.gconfig.max_tokens is None:
58-
logger.info(
59-
"config.gconfig.max_tokens is None, set it to max_new_tokens + max_prompt_len"
60-
)
61-
config.gconfig.max_tokens = (
62-
config.gconfig.max_new_tokens + config.train_dataset.max_length
63-
)
50+
config.gconfig.max_tokens = (
51+
config.gconfig.max_new_tokens + config.train_dataset.max_length
52+
)
6453

6554
if config.enable_colocate_mode:
6655
config.rollout.engine_config["enable_memory_saver"] = True
@@ -122,22 +111,9 @@ def main(args):
122111
train_dataset = dataset["train"]
123112
train_dataset = train_dataset.filter(
124113
lambda x: len(tokenizer.encode(x["prompt"]))
125-
<= config.train_dataset.max_length
114+
<= config.train_dataset.max_length
126115
)
127116

128-
def process(sample):
    """Map one dataset row to chat format: strip legacy role tags from the
    prompt and wrap the remainder as a single user message."""
    prompt = sample["prompt"]
    for tag in ("<role>HUMAN</role>", "<role>ASSISTANT</role>"):
        prompt = prompt.replace(tag, "")
    return {"messages": [{"role": "user", "content": prompt}]}
138-
139-
train_dataset = train_dataset.map(process).remove_columns(["prompt"])
140-
141117
dataloader = StatefulDataLoader(
142118
train_dataset,
143119
batch_size=config.train_dataset.batch_size,
@@ -221,16 +197,21 @@ def process(sample):
221197
if config.actor.hybrid_engine.wrap_policy.kl_ctl > 0:
222198
ref = TrainController(
223199
RemoteHybridTrainWorker,
224-
config.actor,
200+
config.ref,
225201
scheduler,
226202
)
227203

228204
allocation_mode = AllocationMode.from_str(config.allocation_mode)
229205

230206
def init_train_and_rollout_controller_helper(actor, rollout):
231207
logger.info("initializing trainer controller and rollout controller")
232-
actor.initialize(role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec,
233-
group_size=config.gconfig.n_samples, )
208+
actor.initialize(
209+
role="actor",
210+
alloc_mode=allocation_mode,
211+
ft_spec=ft_spec,
212+
group_size=config.gconfig.n_samples,
213+
enable_colocate_mode=config.enable_colocate_mode,
214+
)
234215
rollout.initialize(role="rollout", alloc_mode=allocation_mode)
235216

236217
if config.enable_colocate_mode:
@@ -254,7 +235,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
254235
)
255236

256237
wait_future_ordered(futures)
257-
logger.info(f"initialized all controllers in colocation mode {config.enable_colocate_mode}")
238+
logger.info(
239+
f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
240+
)
258241
else:
259242
with ThreadPoolExecutor(max_workers=3) as executor:
260243
futures = [
@@ -268,7 +251,10 @@ def init_train_and_rollout_controller_helper(actor, rollout):
268251
storage_prefix=config.storage_prefix,
269252
),
270253
executor.submit(
271-
rollout.initialize, role="rollout", alloc_mode=allocation_mode
254+
rollout.initialize,
255+
role="rollout",
256+
alloc_mode=allocation_mode,
257+
enable_colocate_mode=config.enable_colocate_mode,
272258
),
273259
]
274260
if ref is not None:
@@ -324,15 +310,19 @@ def init_train_and_rollout_controller_helper(actor, rollout):
324310
)
325311
clear_dir(weight_update_config.path)
326312
else:
327-
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
313+
with concurrent.futures.ThreadPoolExecutor(
314+
max_workers=2
315+
) as executor:
328316
upload_future = executor.submit(
329317
actor.upload_weights, weight_update_config
330318
)
331319
update_future = executor.submit(
332320
rollout.update_weights, weight_update_config
333321
)
334322
wait_future_ordered([upload_future, update_future])
335-
logger.info(f"{weight_update_config.type} update weight succeeded, step: {step}")
323+
logger.info(
324+
f"{weight_update_config.type} update weight succeeded, step: {step}"
325+
)
336326

337327
with (
338328
stats_tracker.record_timing("rollout_step"),
@@ -360,11 +350,8 @@ def init_train_and_rollout_controller_helper(actor, rollout):
360350
reward_fn="areal.extension.asystem.math_reward.reward_fn",
361351
gconfig=config.gconfig,
362352
tokenizer=config.tokenizer_path,
363-
enable_thinking=False,
364-
dump_dir=os.path.join(
365-
f"{config.storage_prefix}/experiments/logs/root/{config.experiment_name}/{config.trial_name}",
366-
"generated",
367-
),
353+
exp_name=config.experiment_name,
354+
trial_name=config.trial_name,
368355
),
369356
)
370357
else:
@@ -375,24 +362,27 @@ def init_train_and_rollout_controller_helper(actor, rollout):
375362
reward_fn="areal.extension.asystem.math_reward.reward_fn",
376363
gconfig=config.gconfig,
377364
tokenizer=config.tokenizer_path,
378-
enable_thinking=False,
379-
dump_dir=os.path.join(
380-
f"{config.storage_prefix}/experiments/logs/root/{config.experiment_name}/{config.trial_name}",
381-
"generated",
382-
),
365+
exp_name=config.experiment_name,
366+
trial_name=config.trial_name,
383367
),
384368
)
385369

386-
#TODO: calc_training_data_metrics
387-
# with (stats_tracker.scope("training_data"), ):
388-
# calc_training_data_metrics(rollout_res)
370+
# with (
371+
# stats_tracker.scope("training_data"),
372+
# ):
373+
# calc_training_data_metrics(batch.get_data())
389374
# calc_training_data_group_metrics(
390-
# rollout_res, config.gconfig.n_samples
375+
# batch.get_data, config.gconfig.n_samples
391376
# )
392-
# calc_training_data_version_metrics(rollout_res, global_step)
393-
#
394-
logger.info(f"rollout batch res: {batch}, reward: {batch["rewards"]}")
395-
with (stats_tracker.record_timing("notify_rollout_end_event"), ):
377+
# calc_training_data_version_metrics(batch.get_data, global_step)
378+
379+
logger.info(
380+
"rollout batch reward summary: %s",
381+
summarize_rewards(batch["rewards"]),
382+
)
383+
with (
384+
stats_tracker.record_timing("notify_rollout_end_event"),
385+
):
396386
logger.info(
397387
f"start to notify_rollout_end_event, step: {step}, epoch: {epoch}"
398388
)
@@ -420,7 +410,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
420410
stats_tracker.record_timing("train_step"),
421411
stats_tracker.scope("train"),
422412
):
423-
with (stats_tracker.record_timing("notify_train_start_event"),):
413+
with (
414+
stats_tracker.record_timing("notify_train_start_event"),
415+
):
424416
logger.info(
425417
f"start to notify_train_start_event, step: {step}, epoch: {epoch}"
426418
)
@@ -429,12 +421,14 @@ def init_train_and_rollout_controller_helper(actor, rollout):
429421
f"notify_train_start_event succeeded, step: {step}, epoch: {epoch}"
430422
)
431423

432-
with (stats_tracker.record_timing("train_distributed_batch"), ):
424+
with (
425+
stats_tracker.record_timing("train_distributed_batch"),
426+
):
433427
logger.info(f"start to train, step: {step}, epoch: {epoch}")
434428
actor.train_batch(
435429
batch,
436430
loss_fn=lambda logits, batch_data: None,
437-
loss_weight_fn=lambda batch_data: None
431+
loss_weight_fn=lambda batch_data: None,
438432
)
439433
logger.info(
440434
f"train succeeded, step: {step}, epoch: {epoch}"
@@ -487,7 +481,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
487481
f"[Trainer] periodic_checkpoint recover save success, epoch:{epoch}, epoch_step: {step}, global_step:{global_step}"
488482
)
489483

490-
with (stats_tracker.record_timing("notify_train_end_event"),):
484+
with (
485+
stats_tracker.record_timing("notify_train_end_event"),
486+
):
491487
logger.info(
492488
f"start to notify_train_end_event, step: {step}, epoch: {epoch}"
493489
)

0 commit comments

Comments
 (0)