
Commit 083766d

Add new implementations of RL algorithms (#6383)
* add new algorithm
* move common calculations
* delete data
* move common calculations of rewards
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 48a673d commit 083766d

File tree: 5 files changed (+135, -7 lines)

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 75 additions & 5 deletions
@@ -101,6 +101,7 @@ def __init__(
             clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
             beta=grpo_config.get("beta", 0.01),
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
+            adv=grpo_config.get("algo"),
         )

@@ -137,6 +138,8 @@ def __init__(
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )

+        self.adv = grpo_config.get("algo")
+
     def setup(self):
         super().setup()
         if (not self.plugin.pp_size > 1 and self.rank == 0) or (

@@ -204,9 +207,23 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         # [minibatch_size x num_generations]
         reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

-        reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-        # [minibatch_size x num_generations]
-        advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+        if self.adv == "GRPO" or self.adv == "DAPO":
+
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [minibatch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+
+        elif self.adv == "REINFORCE_PPB":
+
+            # [minibatch_size x num_generations]
+            advantages = ((reward - reward_mean)).unsqueeze(dim=-1)
+
+        elif self.adv == "RLOO":
+
+            advantages = (
+                reward * self.num_generations / (self.num_generations - 1)
+                - reward_mean * self.num_generations / (self.num_generations - 1)
+            ).unsqueeze(dim=-1)

         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
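Taken together, this hunk replaces the single group-normalized advantage with a branch keyed on grpo_config["algo"]. Below is a minimal standalone restatement of the three estimators; the helper name and signature are illustrative only and not part of the codebase.

import torch

def compute_advantages(group_reward: torch.Tensor, algo: str, eps: float = 1e-4) -> torch.Tensor:
    # group_reward: [minibatch_size, num_generations], one scalar reward per sampled response.
    # Returns advantages shaped [minibatch_size * num_generations, 1], as in the diff above.
    num_generations = group_reward.size(1)
    reward = group_reward.reshape(-1)
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    if algo in ("GRPO", "DAPO"):
        # Group-normalized: center by the group mean, scale by the group std.
        reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
        advantages = (reward - reward_mean) / (reward_std + eps)
    elif algo == "REINFORCE_PPB":
        # Mean-centered only; per-token whitening is applied later in the consumer.
        advantages = reward - reward_mean
    elif algo == "RLOO":
        # Leave-one-out baseline: r_i minus the mean of the other G - 1 samples,
        # which equals G / (G - 1) * (r_i - group mean), matching the diff.
        advantages = (reward - reward_mean) * num_generations / (num_generations - 1)
    else:
        raise ValueError(f"Unsupported algo: {algo}")
    return advantages.unsqueeze(dim=-1)

For example, compute_advantages(torch.tensor([[1.0, 0.0, 0.0, 1.0]]), "RLOO") gives +2/3 for the two rewarded samples and -2/3 for the others, whereas "GRPO" would additionally divide by the group standard deviation.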
@@ -358,10 +375,34 @@ def _criterion(outputs, inputs):
                         per_token_kl = 0.0
                         kl.append(torch.tensor(0.0))

+                    inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
+
+                    if self.adv == "REINFORCE_PPB":
+
+                        inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+                        advantages_forward_micro_batch_mean = torch.sum(
+                            inputs["advantages"] * inputs["action_mask"]
+                        ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+                        advantages_forward_micro_batch_std = torch.rsqrt(
+                            torch.sum(
+                                (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+                                * inputs["action_mask"]
+                            )
+                            / (torch.sum(inputs["action_mask"]) + 1e-4)
+                            + 1e-8
+                        )
+                        inputs["advantages"] = (
+                            (inputs["advantages"] - advantages_forward_micro_batch_mean)
+                            * inputs["action_mask"]
+                            / (advantages_forward_micro_batch_std)
+                        )
+
+                        per_token_kl = 0.0
+
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
                         inputs["old_action_log_probs"],
-                        inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        inputs["advantages"],
                         per_token_kl,
                         inputs["action_mask"],
                         loss_mask=inputs["loss_mask"],

@@ -420,10 +461,39 @@ def _criterion(outputs, inputs):
                     per_token_kl = 0.0
                     kl = None

+                (
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1)
+                    - self.policy_loss_fn.beta * per_token_kl
+                )
+
+                if self.adv == "REINFORCE_PPB":
+
+                    advantages_forward_micro_batch = (
+                        advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+                    )
+                    advantages_forward_micro_batch_mean = torch.sum(
+                        advantages_forward_micro_batch * action_mask_forward_micro_batch
+                    ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                    advantages_forward_micro_batch_std = torch.rsqrt(
+                        torch.sum(
+                            (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+                            * action_mask_forward_micro_batch
+                        )
+                        / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                        + 1e-8
+                    )
+                    advantages_forward_micro_batch = (
+                        (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+                        * action_mask_forward_micro_batch
+                        / (advantages_forward_micro_batch_std)
+                    )
+
+                    per_token_kl = 0.0
+
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
                     old_action_log_probs_micro_batch,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    advantages_forward_micro_batch,
                     per_token_kl,
                     action_mask_forward_micro_batch,
                     loss_mask=loss_mask_forward_micro_batch,
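Both training paths (the pipeline-parallel _criterion closure and the plain micro-batch loop) gain the same REINFORCE_PPB block: the KL penalty is folded into the advantages, the advantages are whitened over the valid tokens of the forward micro-batch, and per_token_kl is then zeroed before the loss call. A condensed sketch of that step follows; the helper is illustrative, and it assumes the intended scaling is division by the masked standard deviation (the diff stores torch.rsqrt(variance) in the *_std variable and then divides by it).

import torch

def fold_kl_and_whiten(advantages, per_token_kl, action_mask, beta, eps=1e-4):
    # Fold the per-token KL penalty into the advantage signal, then normalize it
    # over the unmasked tokens of the forward micro-batch.
    advantages = advantages - beta * per_token_kl
    denom = torch.sum(action_mask) + eps
    mean = torch.sum(advantages * action_mask) / denom
    inv_std = torch.rsqrt(torch.sum((advantages - mean) ** 2 * action_mask) / denom + 1e-8)
    # Multiply by the reciprocal std (i.e. divide by the std) and zero padded positions;
    # the caller then passes per_token_kl = 0.0 to the policy loss.
    return (advantages - mean) * action_mask * inv_std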

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,13 @@
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {
+    "Simple": SimpleConsumer,
+    "GRPO": GRPOConsumer,
+    "DAPO": GRPOConsumer,
+    "REINFORCE_PPB": GRPOConsumer,
+    "RLOO": GRPOConsumer,
+}


 def get_jsonl_size_fast(path: str) -> int:

@@ -66,6 +72,7 @@ def launch_distributed(
     core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)

     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

     dataset_path = train_dataset_config["path"]
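All of the new map entries resolve to GRPOConsumer; the estimator itself is selected later from grpo_config["algo"] (see the consumer changes above). A self-contained stand-in of the lookup, with placeholder classes:

class SimpleConsumer: ...
class GRPOConsumer: ...

ALGO_MAP = {
    "Simple": SimpleConsumer,
    "GRPO": GRPOConsumer,
    "DAPO": GRPOConsumer,
    "REINFORCE_PPB": GRPOConsumer,
    "RLOO": GRPOConsumer,
}

assert ALGO_MAP.get("RLOO", SimpleConsumer) is GRPOConsumer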

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 2 additions & 0 deletions
@@ -16,13 +16,15 @@ def __init__(
         clip_eps_high: float = 0.2,
         beta: float = 0.01,
         loss_variation: str = "sample_level",
+        adv: str = "GRPO",
     ) -> None:
         super().__init__()
         self.clip_eps_low = clip_eps_low
         self.clip_eps_high = clip_eps_high
         self.beta = beta
         self.loss_variation = loss_variation
         assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
+        self.adv = adv

     def forward(
         self,

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 3 additions & 0 deletions
@@ -118,6 +118,9 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
         self.tokenizer.padding_side = "left"

+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         # init dataloader
         train_dataset_path = train_dataset_config.pop("path")
         self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
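The pad-token fallback matters because some checkpoints ship a tokenizer without a pad token, and left-padded batched generation needs one. A quick standalone check, where the model path is a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your/policy-model-path")  # placeholder path
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
    # Reuse the EOS token as padding, exactly as the diff above does.
    tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token, tokenizer.pad_token_id)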

applications/ColossalChat/rl_example.py

Lines changed: 47 additions & 1 deletion
@@ -137,7 +137,7 @@
     )

     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(

@@ -292,6 +292,7 @@
     if args.algo == "GRPO":
         # Default Settings
         grpo_config = {
+            "algo": "GRPO",
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient

@@ -313,6 +314,7 @@
     elif args.algo == "DAPO":
         # DAPO variant settings
         grpo_config = {
+            "algo": "DAPO",
             "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,

@@ -339,6 +341,50 @@
                 else None
             ),
         }
+    elif args.algo == "REINFORCE_PPB":
+        # Default Settings
+        grpo_config = {
+            "algo": "REINFORCE_PPB",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
+    elif args.algo == "RLOO":
+        # Default Settings
+        grpo_config = {
+            "algo": "RLOO",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
     if args.reward_type == "code":
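With these defaults registered, the two new estimators are reachable from the existing CLI switch, e.g. python rl_example.py -a REINFORCE_PPB -lr 1e-6 -kl 0.01 or python rl_example.py -a RLOO; the remaining arguments (model, dataset, and generation settings) are not shown in this diff and are unchanged.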
