-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Add more training models and RLHF algorithms #6368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: grpo-latest
Are you sure you want to change the base?
Changes from all commits
47ee955
77bd4a4
dd08277
9da096f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,6 +132,8 @@ def __init__( | |
eta_min=0.1 * grpo_config.get("lr", 1e-6), | ||
) | ||
|
||
self.adv = grpo_config.get("algo") | ||
|
||
def setup(self): | ||
super().setup() | ||
if (not self.plugin.pp_size > 1 and self.rank == 0) or ( | ||
|
@@ -180,23 +182,72 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: | |
response_length = torch.sum(action_mask, dim=1).to(torch.float32) | ||
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) | ||
|
||
reward = data["reward"].view((-1)) | ||
format_acc = data["format_acc"].view((-1)) | ||
ans_acc = data["ans_acc"].view((-1)) | ||
# if(True): | ||
|
||
if self.adv == "GRPO" or self.adv == "DAPO": | ||
|
||
reward = data["reward"].view((-1)) | ||
format_acc = data["format_acc"].view((-1)) | ||
ans_acc = data["ans_acc"].view((-1)) | ||
|
||
# [minibatch_size, num_generations] | ||
|
||
group_reward = reward.view(-1, self.num_generations) | ||
reward_mean = group_reward.mean(dim=1) | ||
# [minibatch_size x num_generations] | ||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) | ||
|
||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) | ||
# [minibatch_size x num_generations] | ||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) | ||
|
||
# [minibatch_size x num_of_generation] | ||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() | ||
|
||
elif self.adv == "REINFORCE_PPB": | ||
|
||
reward = data["reward"].view((-1)) | ||
format_acc = data["format_acc"].view((-1)) | ||
ans_acc = data["ans_acc"].view((-1)) | ||
|
||
# [minibatch_size, num_generations] | ||
|
||
group_reward = reward.view(-1, self.num_generations) | ||
reward_mean = group_reward.mean(dim=1) | ||
# [minibatch_size x num_generations] | ||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) | ||
|
||
# [minibatch_size x num_generations] | ||
advantages = ((reward - reward_mean)).unsqueeze(dim=-1) | ||
|
||
advantages_mean = advantages.mean(dim=0) | ||
|
||
advantages_std = advantages.std(dim=0) | ||
|
||
advantages = (advantages - advantages_mean) / (advantages_std + 1e-4) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe consider double-checking the reinforce++ baseline advantage calculation. In reinforce ++, each sample's advantage is calculated by subtracting the mean reward of all generation in the global batch, not per prompt mean |
||
# [minibatch_size x num_of_generation] | ||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() | ||
|
||
elif self.adv == "RLOO": | ||
reward = data["reward"].view((-1)) | ||
format_acc = data["format_acc"].view((-1)) | ||
ans_acc = data["ans_acc"].view((-1)) | ||
|
||
# [minibatch_size, num_generations] | ||
# [minibatch_size, num_generations] | ||
|
||
group_reward = reward.view(-1, self.num_generations) | ||
reward_mean = group_reward.mean(dim=1) | ||
# [minibatch_size x num_generations] | ||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) | ||
group_reward = reward.view(-1, self.num_generations) | ||
reward_mean = group_reward.mean(dim=1) | ||
# [minibatch_size x num_generations] | ||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) | ||
|
||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) | ||
# [minibatch_size x num_generations] | ||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) | ||
advantages = ( | ||
reward * self.num_generations / (self.num_generations - 1) | ||
- reward_mean * self.num_generations / (self.num_generations - 1) | ||
).unsqueeze(dim=-1) | ||
|
||
# [minibatch_size x num_of_generation] | ||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() | ||
# [minibatch_size x num_of_generation] | ||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. may be better to move the common calculations outside of the if statements for conciseness |
||
# filter out overlength samples | ||
if self.filter_truncated_response and action_mask.size(1) == self.max_length: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
4.51.0: qwen2.5 + grpo, qwen3 + grpo, cannot: llama2, llama3.2 | ||
4.47.0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove test log file |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. probably also consider force num_generation to 1 for reinforce++ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,7 +130,7 @@ | |
) | ||
|
||
# GRPO parameters | ||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"]) | ||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"]) | ||
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.") | ||
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.") | ||
parser.add_argument( | ||
|
@@ -227,13 +227,13 @@ | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock | ||
|
||
inference_model_config = dict(path=args.model) | ||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) | ||
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is flash attention not supported? |
||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) | ||
|
||
if args.backend == "transformers": | ||
inference_model_config.update( | ||
dict( | ||
use_flash_attention_2=True, | ||
use_flash_attention_2=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
torch_dtype=torch.bfloat16, | ||
) | ||
) | ||
|
@@ -283,6 +283,7 @@ | |
if args.algo == "GRPO": | ||
# Default Settings | ||
grpo_config = { | ||
"algo": "GRPO", | ||
"lr": args.learning_rate, | ||
"train_microbatch_size": args.train_microbatch_size, | ||
"beta": args.kl_coeff, # KL penalty coefficient | ||
|
@@ -304,6 +305,7 @@ | |
elif args.algo == "DAPO": | ||
# DAPO variant settings | ||
grpo_config = { | ||
"algo": "DAPO", | ||
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch | ||
"lr": args.learning_rate, | ||
"train_microbatch_size": args.train_microbatch_size, | ||
|
@@ -330,6 +332,50 @@ | |
else None | ||
), | ||
} | ||
elif args.algo == "REINFORCE_PPB": | ||
# Default Settings | ||
grpo_config = { | ||
"algo": "REINFORCE_PPB", | ||
"lr": args.learning_rate, | ||
"train_microbatch_size": args.train_microbatch_size, | ||
"beta": args.kl_coeff, # KL penalty coefficient | ||
"loss_variation": "sample_level", | ||
"reward_fn_type": args.reward_type, | ||
"max_length": args.max_new_tokens + args.max_prompt_tokens, | ||
"max_new_tokens": args.max_new_tokens, | ||
"response_format_tags": ( | ||
{ | ||
"think_start": {"text": "<think>", "num_occur": 1}, | ||
"think_end": {"text": "</think>", "num_occur": 1}, | ||
"answer_start": {"text": "<answer>", "num_occur": 1}, | ||
"answer_end": {"text": "</answer>", "num_occur": 1}, | ||
} | ||
if args.reward_type == "think_answer_tags" | ||
else None | ||
), | ||
} | ||
elif args.algo == "RLOO": | ||
# Default Settings | ||
grpo_config = { | ||
"algo": "RLOO", | ||
"lr": args.learning_rate, | ||
"train_microbatch_size": args.train_microbatch_size, | ||
"beta": args.kl_coeff, # KL penalty coefficient | ||
"loss_variation": "sample_level", | ||
"reward_fn_type": args.reward_type, | ||
"max_length": args.max_new_tokens + args.max_prompt_tokens, | ||
"max_new_tokens": args.max_new_tokens, | ||
"response_format_tags": ( | ||
{ | ||
"think_start": {"text": "<think>", "num_occur": 1}, | ||
"think_end": {"text": "</think>", "num_occur": 1}, | ||
"answer_start": {"text": "<answer>", "num_occur": 1}, | ||
"answer_end": {"text": "</answer>", "num_occur": 1}, | ||
} | ||
if args.reward_type == "think_answer_tags" | ||
else None | ||
), | ||
} | ||
else: | ||
raise ValueError(f"Unsupported algorithm: {args.algo}") | ||
if args.reward_type == "code": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't the advantages_mean always 0 as advantage is already zero-centered in the previous step?