
Commit 489f215

Merge pull request #6250 from hpcaitech/grpo-latest-dev
[feat] Fix Vllm, logprob, add filtering, temperature annealing, lr descent
2 parents 7795d4c + 2aa7385 commit 489f215

File tree: 8 files changed (+239, -42 lines)


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 3 additions & 0 deletions

@@ -57,6 +57,7 @@ def __init__(
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

         self.device = get_current_device()
+        self.lr_scheduler = None

     def setup(self) -> None:
         for i in range(self.num_producers):
@@ -121,6 +122,8 @@ def loop(self) -> None:
                                pbar.set_postfix({"loss": loss})
                            i += 1
                    assert len(self.buffer) == 0
+                    if self.lr_scheduler is not None:
+                        self.lr_scheduler.step()
                    if (step + 1) % self.save_interval == 0:
                        if self.rank == 0:
                            print(f"Start saving policy model at step {step + 1}.")

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 172 additions & 13 deletions

@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional

 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam


@@ -31,8 +35,10 @@ def __init__(
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -52,7 +58,7 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ def __init__(
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = training_config.get("filter_range", None)
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."

         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,11 +94,21 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+
+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )

     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
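Note: the learning-rate descent comes from CosineAnnealingWarmupLR with warmup_steps=0, decaying from the configured lr down to eta_min = 0.1 * lr over min(num_episodes, 4) * num_update_per_episode optimizer steps, and the scheduler is passed through booster.boost so it is managed alongside the optimizer. A rough sketch of the resulting schedule, using torch's CosineAnnealingLR as a stand-in and made-up step counts:

# Sketch of the LR descent: cosine decay from lr to 0.1 * lr (stand-in scheduler, illustrative step counts).
import torch

lr, total_steps = 1e-6, 4 * 19  # illustrative values, not taken from the commit
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0.1 * lr)

for _ in range(total_steps):
    optimizer.step()   # one global update
    scheduler.step()   # mirrors the consumer stepping the scheduler once per update
print(scheduler.get_last_lr()[0])  # ends near 0.1 * lr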
@@ -113,15 +133,17 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             with torch.no_grad():
@@ -130,7 +152,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / self.generate_config["temperature"],
+                    data["input_ids"],
+                    num_action,
+                    self.plugin.shard_config,
                 )

             per_token_kl = (
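Note: both the policy and the reference logits are now divided by the rollout temperature before calc_action_log_probs, so the log-probabilities used for the loss and the KL term match the temperature-scaled distribution the inference backend actually sampled from. A minimal sketch of that computation, assuming calc_action_log_probs is conceptually a log-softmax plus a gather over the generated tokens:

# Sketch: temperature-consistent per-token log-probs (conceptual stand-in for calc_action_log_probs).
import torch
import torch.nn.functional as F

def action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_actions: int, temperature: float = 1.0):
    # logits: [B, S, V] over the full prompt+response; the last num_actions tokens are the response.
    log_probs = F.log_softmax(logits[:, :-1, :] / temperature, dim=-1)
    token_logp = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return token_logp[:, -num_actions:]  # [B, num_actions]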
@@ -149,21 +174,31 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
+
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
+            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
+            )

             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)

-            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )

             if not skip_update:
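Note: the new filter_range drops whole prompt groups whose mean reward falls outside (low, high); a group where every sample scores zero or every sample scores full marks carries essentially no advantage signal, so its generations are masked out of the loss. A self-contained sketch of the grouping, filtering, and advantage normalization on dummy rewards:

# Sketch of group filtering + GRPO-style advantage normalization on dummy rewards.
import torch

num_generations, filter_range = 4, (0.01, 0.99)
reward = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0])  # 2 prompts x 4 generations

group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1)
# Keep only groups whose mean reward lies strictly inside the filter range.
loss_mask = torch.logical_and(
    reward_mean > filter_range[0], reward_mean < filter_range[1]
).repeat_interleave(num_generations, dim=0)

reward_mean = reward_mean.repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (reward - reward_mean) / (reward_std + 1e-4)

print(loss_mask)   # tensor([False, False, False, False,  True,  True,  True,  True])
print(advantages)  # all-zero group yields zero advantages; the mixed group is standardized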
@@ -207,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             self.wandb_run.log(
                 {
+                    "metrics/reward": self.accum_reward.item() / self.accum_count,
+                    "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                    "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                    "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                     "train/loss": self.accum_loss.item() / self.accum_count,
-                    "train/reward": self.accum_reward.item() / self.accum_count,
-                    "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                    "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                     "train/kl": self.accum_kl.item() / self.accum_count,
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
-                    "train/response_length": self.accum_response_length.item() / self.accum_count,
+                    "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                    "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
             )
             self.accum_loss.zero_()
@@ -232,3 +269,125 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
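Note: GRPOEvalConsumer appends one JSON object per decoded response to {log_dir}/eval_results_rank_{rank}.jsonl (log_dir defaults to ./results and is wiped on startup). A small sketch of aggregating those per-rank files offline, assuming the default directory:

# Sketch: aggregate the per-rank eval logs written by GRPOEvalConsumer (assumes log_dir="./results").
import glob
import json

records = []
for path in glob.glob("./results/eval_results_rank_*.jsonl"):
    with open(path, encoding="utf8") as f:
        records.extend(json.loads(line) for line in f)

if records:
    print("samples:", len(records))
    print("mean acc_reward:", sum(r["acc_reward"] for r in records) / len(records))
    print("mean response_length:", sum(r["response_length"] for r in records) / len(records))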

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 31 additions & 8 deletions

@@ -53,15 +53,21 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
         self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:


 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,27 +187,38 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []