Add more training models and RLHF algorithms #6368

Status: Open. Wants to merge 4 commits into base: grpo-latest.

1 change: 1 addition & 0 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -396,6 +396,7 @@ def apply_chat_template_and_mask(
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
attention_mask = attention_mask[:max_length]

input_ids = torch.tensor(tokens, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = input_ids.clone()
77 changes: 64 additions & 13 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -132,6 +132,8 @@ def __init__(
eta_min=0.1 * grpo_config.get("lr", 1e-6),
)

self.adv = grpo_config.get("algo")

def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -180,23 +182,72 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))
# if(True):

if self.adv == "GRPO" or self.adv == "DAPO":

reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))

# [minibatch_size, num_generations]

group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

elif self.adv == "REINFORCE_PPB":

reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))

# [minibatch_size, num_generations]

group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

# [minibatch_size x num_generations]
advantages = ((reward - reward_mean)).unsqueeze(dim=-1)

advantages_mean = advantages.mean(dim=0)
Reviewer comment (Contributor): Isn't advantages_mean always 0, since the advantage is already zero-centered in the previous step?
advantages_std = advantages.std(dim=0)

advantages = (advantages - advantages_mean) / (advantages_std + 1e-4)

Reviewer comment (Contributor): Maybe consider double-checking the REINFORCE++ baseline advantage calculation. In REINFORCE++, each sample's advantage is computed by subtracting the mean reward of all generations in the global batch, not the per-prompt mean.

# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

elif self.adv == "RLOO":
reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))

# [minibatch_size, num_generations]
# [minibatch_size, num_generations]

group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
advantages = (
reward * self.num_generations / (self.num_generations - 1)
- reward_mean * self.num_generations / (self.num_generations - 1)
).unsqueeze(dim=-1)

# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

Reviewer comment (Contributor): It may be better to move the common calculations outside of the if statements for conciseness.

# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
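
For reference, here is a minimal standalone sketch of the two newer advantage estimators discussed in the review comments above, assuming only a flat reward tensor of shape [minibatch_size * num_generations]. It illustrates the math rather than the PR's implementation: the closed form used in the RLOO branch is algebraically the leave-one-out baseline, and the second helper shows the global-batch baseline the reviewer describes for REINFORCE++.

import torch

def rloo_advantages(reward: torch.Tensor, num_generations: int) -> torch.Tensor:
    """Leave-one-out: the baseline for each sample is the mean reward of the other G - 1 samples in its group."""
    group = reward.view(-1, num_generations)  # [minibatch, G]
    baseline = (group.sum(dim=1, keepdim=True) - group) / (num_generations - 1)
    return (group - baseline).view(-1, 1)  # [minibatch * G, 1]

def reinforce_pp_advantages(reward: torch.Tensor) -> torch.Tensor:
    """Global-batch baseline (the variant the reviewer describes): subtract the batch mean, scale by the batch std."""
    centered = reward - reward.mean()
    return (centered / (centered.std() + 1e-4)).unsqueeze(dim=-1)

# Sanity check: the PR's closed form G/(G-1) * (r_i - group_mean) equals the leave-one-out definition.
G = 8
reward = torch.randn(4 * G)  # 4 prompts x 8 generations; shapes chosen only for illustration
group_mean = reward.view(-1, G).mean(dim=1).repeat_interleave(G, dim=0)
closed_form = (reward * G / (G - 1) - group_mean * G / (G - 1)).unsqueeze(dim=-1)
assert torch.allclose(rloo_advantages(reward, G), closed_form, atol=1e-5)
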
8 changes: 7 additions & 1 deletion applications/ColossalChat/coati/distributed/launch.py
@@ -9,7 +9,13 @@
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
"DAPO": GRPOConsumer,
"REINFORCE_PPB": GRPOConsumer,
"RLOO": GRPOConsumer,
}


def get_jsonl_size_fast(path: str) -> int:
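
All four RL algorithms resolve to the same GRPOConsumer; the behavioral switch happens through the "algo" key that rl_example.py places in grpo_config and that the consumer stores as self.adv. Below is a self-contained sketch of the registry pattern, with placeholder classes standing in for the real consumers (the actual launch() wiring may differ).

class SimpleConsumer: ...
class GRPOConsumer: ...

ALGO_MAP = {
    "Simple": SimpleConsumer,
    "GRPO": GRPOConsumer,
    "DAPO": GRPOConsumer,
    "REINFORCE_PPB": GRPOConsumer,
    "RLOO": GRPOConsumer,
}

def resolve_consumer(algo: str):
    if algo not in ALGO_MAP:
        raise NotImplementedError(f"Unsupported algorithm: {algo}")
    return ALGO_MAP[algo]

assert resolve_consumer("RLOO") is GRPOConsumer  # every new algorithm reuses the GRPO consumer
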
3 changes: 3 additions & 0 deletions applications/ColossalChat/coati/distributed/producer.py
@@ -115,6 +115,9 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"

if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
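
The added guard matters because some tokenizers (the Llama family, for example) ship without a pad token, and left-padded batching then fails. A minimal sketch of the same fallback in isolation; the checkpoint name is illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as PAD so batched left-padding works
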
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/distributed/untitled.txt
@@ -0,0 +1,2 @@
4.51.0: qwen2.5 + grpo, qwen3 + grpo, cannot: llama2, llama3.2
4.47.0:
Reviewer comment (Contributor): Remove the test log file.

52 changes: 49 additions & 3 deletions applications/ColossalChat/rl_example.py
Reviewer comment (Contributor): Probably also consider forcing num_generations to 1 for REINFORCE++.
@@ -130,7 +130,7 @@
)

# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -227,13 +227,13 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
Reviewer comment (Contributor): Why is flash attention not supported?
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
use_flash_attention_2=False,
Reviewer comment (Contributor): Same here.
torch_dtype=torch.bfloat16,
)
)
@@ -283,6 +283,7 @@
if args.algo == "GRPO":
# Default Settings
grpo_config = {
"algo": "GRPO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
@@ -304,6 +305,7 @@
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
"algo": "DAPO",
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
@@ -330,6 +332,50 @@
else None
),
}
elif args.algo == "REINFORCE_PPB":
# Default Settings
grpo_config = {
"algo": "REINFORCE_PPB",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "RLOO":
# Default Settings
grpo_config = {
"algo": "RLOO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code":
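
The REINFORCE_PPB and RLOO blocks above duplicate the GRPO defaults apart from the "algo" key. A hedged sketch of one way the shared defaults could be factored out, shown purely as an illustration and not as part of the PR:

# Illustrative only: `args` is the namespace from the script's existing argparse setup.
base_grpo_config = {
    "lr": args.learning_rate,
    "train_microbatch_size": args.train_microbatch_size,
    "beta": args.kl_coeff,  # KL penalty coefficient
    "loss_variation": "sample_level",
    "reward_fn_type": args.reward_type,
    "max_length": args.max_new_tokens + args.max_prompt_tokens,
    "max_new_tokens": args.max_new_tokens,
    "response_format_tags": (
        {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        if args.reward_type == "think_answer_tags"
        else None
    ),
}
if args.algo in ("GRPO", "REINFORCE_PPB", "RLOO"):
    grpo_config = {"algo": args.algo, **base_grpo_config}
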
4 changes: 3 additions & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -141,7 +141,9 @@ def llama_model_forward(
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, None, False
)

# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()
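
The two extra positional arguments presumably correspond to past_key_values=None and output_attentions=False in the newer transformers signature of LlamaModel._update_causal_mask; this reading is inferred from upstream and is not stated in the PR.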