Commit bbe1517: update

1 parent 91cf189 commit bbe1517

File tree: 2 files changed, +361 −6 lines

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
import os
import sys
from copy import deepcopy

import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.alloc_mode import AllocationMode
from areal.api.cli_args import GRPOConfig, load_expr_config
from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
from areal.dataset import get_custom_dataset
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.platforms import current_platform
from areal.utils import seeding, stats_tracker
from areal.utils.data import (
    broadcast_tensor_container,
    cycle_dataloader,
    tensor_container_to,
)
from areal.utils.device import log_gpu_stats
from areal.utils.evaluator import Evaluator
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.recover import RecoverHandler
from areal.utils.saver import Saver
from areal.utils.stats_logger import StatsLogger
from areal.workflow.rlvr import RLVRWorkflow

from typing import TYPE_CHECKING, Optional
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

if TYPE_CHECKING:
    from datasets import Dataset
    from transformers.processing_utils import ProcessorMixin
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

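# Binary RLVR reward: uses AReaL's math parser to score the completion against the
# reference answer (1 = correct, 0 = incorrect).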
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    from areal.reward.math_parser import process_results

    return int(process_results(completions, answer)[0])


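# Load the GReSO parquet data, keep the `messages`/`answer` fields used by the RLVR workflow,
# optionally drop over-long prompts, and shard the rows across data-parallel ranks.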
def load_greso_dataset(
    path: str,
    rank: int,
    world_size: int,
    type: str = "sft",
    split: Optional[str] = None,
    max_length: Optional[int] = None,
    tokenizer: Optional["PreTrainedTokenizerFast"] = None,
    processor: Optional["ProcessorMixin"] = None,
    **kwargs,
) -> "Dataset":
    dataset = load_dataset("parquet", data_dir=path, split=split)

    def process(sample):
        return {"messages": sample["messages"], "answer": sample["answer"]}

    dataset = dataset.map(process)

    # Filter out sequences longer than max_length if tokenizer and max_length are provided
    if max_length is not None and tokenizer is not None:

        def filter_length(sample):
            # Tokenize the user content to check length
            content = sample["messages"][0]["content"]
            tokens = tokenizer.encode(content)
            return len(tokens) <= max_length

        dataset = dataset.filter(filter_length)

    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return dataset


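# Training entry point: remote SGLang engines generate rollouts, an FSDP PPO actor performs the
# GRPO/DAPO updates, with periodic evaluation, checkpointing, and crash recovery.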
def main(args):
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
    allocation_mode = AllocationMode.from_str(config.allocation_mode)
    parallel_strategy = allocation_mode.train
    assert parallel_strategy is not None

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.create_process_group(parallel_strategy=parallel_strategy)

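    # Each data-parallel rank loads only its own shard of the train/valid splits.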
    train_dataset = load_greso_dataset(
        path=config.train_dataset.path,
        rank=actor.data_parallel_rank,
        world_size=actor.data_parallel_world_size,
        split="train",
        max_length=config.train_dataset.max_length,
        type=config.train_dataset.type,
        tokenizer=tokenizer,
    )
    valid_dataset = load_greso_dataset(
        path=config.valid_dataset.path,
        rank=actor.data_parallel_rank,
        world_size=actor.data_parallel_world_size,
        split="test",
        max_length=config.valid_dataset.max_length,
        type=config.valid_dataset.type,
        tokenizer=tokenizer,
    )

    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        train_dataset,
        batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        valid_dataset,
        batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.valid_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

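    # Generation runs on remote SGLang servers; a separate engine (a deep copy of the rollout
    # config) is reserved for evaluation only.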
    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
    eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout))
    # NOTE: eval does not have any offpolicyness control
    eval_rollout.config.max_head_offpolicyness = int(1e12)
    eval_rollout.initialize()

    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.create_process_group(parallel_strategy=parallel_strategy)
        ref.initialize(None, ft_spec)

    # NOTE: Weight update meta only requires address and free port of rank 0,
    # but `WeightUpdateMeta.from_fsdp_xccl` has to be executed on all ranks
    # due to `engine.get_param_specs()`.
    # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
    weight_update_meta = [
        WeightUpdateMeta.from_fsdp_xccl(
            AllocationMode.from_str(config.allocation_mode), actor
        )
    ]
    dist.broadcast_object_list(weight_update_meta, src=0)
    weight_update_meta = weight_update_meta[0]

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
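    # Training and eval share the same reward function; the eval workflow decodes at a fixed
    # temperature (0.6) and logs its rollout stats under a separate "eval-rollout" scope.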
    workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
    )
    eval_workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig.new(temperature=0.6),
        tokenizer=tokenizer,
        enable_thinking=False,
        rollout_stat_scope="eval-rollout",
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated-eval"
        ),
    )

    # Run training.
    saver = Saver(config.saver, ft_spec)
    stats_logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    recover_handler = RecoverHandler(config.recover, ft_spec)
    recover_info = recover_handler.load(
        actor,
        saver,
        evaluator,
        stats_logger,
        train_dataloader,
        inference_engine=rollout,
        weight_update_meta=weight_update_meta,
    )
    start_step = (
        recover_info.last_step_info.next().global_step
        if recover_info is not None
        else 0
    )

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

    data_generator = cycle_dataloader(train_dataloader)
    for global_step in range(start_step, max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch
        step_info = StepInfo(
            global_step=global_step,
            epoch=epoch,
            epoch_step=step,
            steps_per_epoch=steps_per_epoch,
        )

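        # Rollout: only the data-parallel head generates (async or sync); the batch is then
        # broadcast to its context/model-parallel peers so every rank trains on the same data.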
with stats_tracker.record_timing("rollout"):
222+
batch = None
223+
if actor.is_data_parallel_head():
224+
if config.async_training:
225+
batch = rollout.prepare_batch(
226+
train_dataloader,
227+
workflow=workflow,
228+
should_accept=lambda sample: True,
229+
)
230+
else:
231+
batch = rollout.rollout_batch(
232+
next(data_generator),
233+
workflow=workflow,
234+
should_accept=lambda sample: True,
235+
)
236+
batch = tensor_container_to(batch, actor.device)
237+
batch = broadcast_tensor_container(
238+
batch,
239+
src_rank=actor.current_data_parallel_head(),
240+
group=actor.context_and_model_parallel_group,
241+
)
242+
# Create barrier to synchronize all rollout processes.
243+
dist.barrier(device_ids=[actor.device.index])
244+
current_platform.synchronize()
245+
246+
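        # Optionally recompute the sampled tokens' log-probs under the current policy and store
        # them as `prox_logp` (consumed by the decoupled PPO loss).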
        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

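        # Weight sync: pause generation, push updated FSDP weights to the rollout servers
        # (rank 0 drives the request; every rank uploads its shard), then bump the version counters.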
        # pause inference for updating weights, save, and evaluation
        rollout.pause()

        with stats_tracker.record_timing("update_weights"):
            if dist.get_rank() == 0:
                future = rollout.update_weights(weight_update_meta)
            actor.upload_weights(weight_update_meta)
            if dist.get_rank() == 0:
                future.result()
            dist.barrier(device_ids=[actor.device.index])
            current_platform.synchronize()

        actor.set_version(global_step + 1)
        rollout.set_version(global_step + 1)
        eval_rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)

        with stats_tracker.record_timing("checkpoint_for_recover"):
            recover_handler.dump(
                actor,
                step_info,
                saver,
                evaluator,
                stats_logger,
                train_dataloader,
                tokenizer=tokenizer,
            )

        dist.barrier(device_ids=[actor.device.index])
        current_platform.synchronize()

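        # Periodic evaluation: the DP head streams the whole validation set through the eval
        # engine and blocks until every sample has been scored.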
with stats_tracker.record_timing("eval"):
303+
304+
def evaluate_fn():
305+
if actor.is_data_parallel_head():
306+
# Stats are logged in workflow
307+
# and will be exported later
308+
cnt = 0
309+
for data in valid_dataloader:
310+
for item in data:
311+
eval_rollout.submit(item, eval_workflow)
312+
cnt += 1
313+
eval_rollout.wait(cnt, timeout=None)
314+
dist.barrier(device_ids=[actor.device.index])
315+
current_platform.synchronize()
316+
317+
evaluator.evaluate(
318+
evaluate_fn,
319+
epoch,
320+
step,
321+
global_step,
322+
)
323+
324+
dist.barrier(device_ids=[actor.device.index])
325+
current_platform.synchronize()
326+
327+
# Upload statistics to the logger (e.g., wandb)
328+
stats[0].update(
329+
stats_tracker.export_all(reduce_group=actor.data_parallel_group)
330+
)
331+
stats_logger.commit(epoch, step, global_step, stats)
332+
333+
dist.barrier(device_ids=[actor.device.index])
334+
current_platform.synchronize()
335+
336+
# Resume rollout
337+
rollout.resume()
338+
339+
stats_logger.close()
340+
eval_rollout.destroy()
341+
rollout.destroy()
342+
if ref is not None:
343+
ref.destroy()
344+
actor.destroy()
345+
346+
347+
if __name__ == "__main__":
348+
main(sys.argv[1:])

scripts/dapo.sh

Lines changed: 13 additions & 6 deletions
@@ -1,16 +1,18 @@
 #!/usr/bin/env bash
 set -euo pipefail
-export CUDA_VISIBLE_DEVICES=0,1,2
+export CUDA_VISIBLE_DEVICES=1,2
 N_GPU=2
-EXP_NAME=gsm8k-dapo
-TRIAL_NAME=trial1
+EXP_NAME=greso-dapo
+TRIAL_NAME=trial0
 FILE_ROOT=/data/yl/AReaL/tmp/areal/experiments
 ACTOR_PATH=/data/yl/model/Qwen/Qwen2.5-1.5B-Instruct
+TRAIN_DATASET_PATH=/data/yl/dataset/greso
+VALID_DATASET_PATH=/data/yl/dataset/greso
 
 TOTAL_TRAIN_EPOCHS=1
 
 python3 -m areal.launcher.local \
-examples/experimental/dapo/gsm8k_dapo.py \
+examples/experimental/dapo/greso_dapo.py \
 --config examples/experimental/dapo/gsm8k_dapo.yaml \
 experiment_name="$EXP_NAME" \
 trial_name="$TRIAL_NAME" \
@@ -19,7 +21,12 @@ python3 -m areal.launcher.local \
 cluster.n_nodes=1 \
 cluster.n_gpus_per_node="$N_GPU" \
 cluster.fileroot="$FILE_ROOT" \
++gconfig.top_p=0.7 \
 actor.path="$ACTOR_PATH" \
 actor.optimizer.lr=1e-6 \
-actor.optimizer.weight_decay=0.1 \
-actor.overlong_reward_penalty=false
+actor.optimizer.weight_decay=0.01 \
+actor.overlong_reward_penalty=false \
+actor.ppo_n_minibatches=64 \
++actor.c_clip=10.0 \
+train_dataset.path="$TRAIN_DATASET_PATH" \
+valid_dataset.path="$VALID_DATASET_PATH"