Skip to content

Commit b96d690

Browse files
Author: Tong Li (committed)
Commit message: "grpo consumer"
Commit b96d690 — 1 parent: c15225b

File tree

1 file changed

+26
-30
lines changed

1 file changed

+26
-30
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ def __init__(
5252
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
5353
self.policy_model.train()
5454
self.policy_model.gradient_checkpointing_enable()
55-
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
55+
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
5656
self.accum_loss = torch.zeros(1, device=self.device)
5757
self.accum_reward = torch.zeros(1, device=self.device)
5858
self.accum_kl = torch.zeros(1, device=self.device)
59+
self.accum_count = 0
5960

6061
# Reference model is initialized from policy model.
6162
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -79,13 +80,7 @@ def __init__(
7980
self.policy_loss_fn = PolicyLoss()
8081
self.global_step = 0
8182
if self.rank == 0:
82-
self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
83-
# import os
84-
# import time
85-
86-
# log_dir = self.wandb_run.dir
87-
# # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
88-
# # self.writer = SummaryWriter(log_dir=log_dir)
83+
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
8984

9085
def setup(self):
9186
super().setup()
@@ -129,66 +124,67 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
129124
)["logits"]
130125
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
131126

132-
# GRPO advantage calculation
133-
kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
134-
action_mask, dim=-1
127+
per_token_kl = (
128+
torch.exp(reference_action_log_probs - action_log_probs)
129+
- (reference_action_log_probs - action_log_probs)
130+
- 1
135131
)
132+
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
136133

137134
reward = self.reward_model(
138135
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
139136
)
140-
reward = kl + reward
141137
# [batch_size, num_generations]
142138
group_reward = reward.view(-1, self.num_generations)
143139

144140
# [batch_size x num_generations]
145141
reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
146142
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
147143
# [batch_size x num_generations]
148-
advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
149-
150-
# GRPO advantage calculation
151-
kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
152-
action_mask, dim=-1
153-
)
144+
advantages = (reward - reward_mean) / (reward_std + 1e-4)
154145

155146
# Calculate Loss
156147
loss, skip_update, _ = self.policy_loss_fn(
157148
action_log_probs,
158149
old_action_log_probs,
159150
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
151+
per_token_kl,
160152
action_mask,
161153
)
162154

163-
loss = loss / self.num_microbatches
164155
if not skip_update:
165156
self.booster.backward(loss, self.optimizer)
166-
loss = all_reduce_mean(loss)
167-
reward = all_reduce_mean(reward.mean())
168-
kl = all_reduce_mean(kl.mean())
157+
loss = all_reduce_mean(loss, self.plugin)
158+
reward = all_reduce_mean(reward.mean(), self.plugin)
159+
kl = all_reduce_mean(kl.mean(), self.plugin)
169160
self.accum_loss.add_(loss.data)
170161
self.accum_reward.add_(reward.data)
171162
self.accum_kl.add_(kl.data)
163+
self.accum_count += 1
172164
if need_update:
173165
self.optimizer.step()
174166
self.optimizer.zero_grad()
175167
loss_scalar = self.accum_loss.item()
176168
if self.rank == 0:
177-
print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
169+
print(
170+
"Loss:",
171+
self.accum_loss.item() / self.accum_count,
172+
"Reward:",
173+
self.accum_reward.item() / self.accum_count,
174+
"KL:",
175+
self.accum_kl.item() / self.accum_count,
176+
)
178177
self.wandb_run.log(
179178
{
180-
"train/loss": self.accum_loss.item(),
181-
"train/reward": self.accum_reward.item(),
182-
"train/kl": self.accum_kl.item(),
179+
"train/loss": self.accum_loss.item() / self.accum_count,
180+
"train/reward": self.accum_reward.item() / self.accum_count,
181+
"train/kl": self.accum_kl.item() / self.accum_count,
183182
}
184183
)
185-
# self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
186-
# self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
187-
# self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
188-
# self.global_step += 1
189184
self.accum_loss.zero_()
190185
self.accum_reward.zero_()
191186
self.accum_kl.zero_()
187+
self.accum_count = 0
192188
return loss_scalar
193189

194190
def state_dict(self):

0 commit comments

Comments (0)