Commit 3c42c0c

Merge pull request #6309 from hpcaitech/grpo-eval-dev
[feat] Support evaluation during training
2 parents ab95624 + 021914c commit 3c42c0c

File tree

9 files changed: +312 -87 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 2 additions & 1 deletion

@@ -93,7 +93,6 @@ def setup(self) -> None:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []
-
         self.recv_cnt = 0

     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -209,6 +208,8 @@ def __init__(
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
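The subclass now forwards save_interval and save_dir to the base consumer's constructor, so periodic checkpointing is configured in one place. A minimal sketch of the kind of interval check these two values typically drive; the helper, the loop, and the file naming below are assumptions for illustration, not the repository's own save logic:

import os
import torch
import torch.nn as nn

# Hypothetical interval-based checkpointing driven by save_interval / save_dir.
def maybe_save(model: nn.Module, step: int, save_interval: int, save_dir: str) -> None:
    # Save only on steps that are exact multiples of the interval.
    if save_interval > 0 and step % save_interval == 0:
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(save_dir, f"step_{step}.pt"))

model = nn.Linear(4, 4)
for step in range(1, 11):
    maybe_save(model, step, save_interval=5, save_dir="./model")  # saves at steps 5 and 10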

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 18 additions & 13 deletions

@@ -33,12 +33,13 @@ def __init__(
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -84,6 +85,9 @@ def __init__(
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name

         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -134,7 +138,6 @@ def __init__(
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb

         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -145,13 +148,16 @@ def __init__(

     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -265,7 +271,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         self.total_sample_count += total_samples.item()
-
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
@@ -506,8 +511,8 @@ def _criterion(outputs, inputs):
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-
-                self.wandb_run.log(metrics)
+                if self.wandb_run is not None:
+                    self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_ans_acc.zero_()
@@ -521,7 +526,6 @@ def _criterion(outputs, inputs):
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-
                 return loss_scalar, excessive_prompts_idx
             else:
                 return None, excessive_prompts_idx
@@ -530,4 +534,5 @@ def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
+        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
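state_dict() now piggybacks the trainer's global step on the weight payload that is synced out of the consumer. A small, self-contained sketch of how a receiving process might separate that bookkeeping entry from the real parameters before loading; the receiver-side code is an assumption for illustration, not taken from this diff:

import torch
import torch.nn as nn

# Stand-in for the synced payload: model weights plus the extra bookkeeping key.
model = nn.Linear(4, 4)
state_dict = model.state_dict()
state_dict["consumer_global_step"] = torch.tensor([1234])

# Hypothetical receiver side: pop the step counter first so load_state_dict only
# sees genuine parameters, then use the step e.g. to label evaluation outputs.
consumer_step = int(state_dict.pop("consumer_global_step").item())
model.load_state_dict(state_dict)
print(f"weights are from trainer step {consumer_step}")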

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 7 additions & 5 deletions

@@ -205,7 +205,8 @@ def __init__(
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = SamplingParams(**generate_config)
+        self.generate_config = generate_config
+        self.sample_params = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -219,8 +220,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
+        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
         out_tokens = []
         out_len = []
@@ -236,7 +238,7 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
                 log_probs.append(p)

         # pad them
-        max_len = self.generate_config.max_tokens
+        max_len = self.sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)

         for i, new_token_ids in enumerate(out_tokens):
@@ -266,11 +268,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
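generate() now looks up sample_params in kwargs and falls back to the SamplingParams built from generate_config, so evaluation can decode with different settings than rollout without mutating the backend. A short sketch of what such an override could look like; the concrete decoding values and the call sites are assumptions for illustration:

from vllm import SamplingParams

# Hypothetical evaluation-time override: deterministic decoding with a single
# completion per prompt instead of the training-time n = num_generations.
eval_sample_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=512, n=1)

# Training call (falls back to self.sample_params):
#   backend.generate(input_ids, attention_mask, gt_answer=gt)
# Evaluation call (per-request override):
#   backend.generate(input_ids, attention_mask, gt_answer=gt, sample_params=eval_sample_params)

Because the outputs are now reshaped with view(micro_batch_size, -1, ...) and gt_answer is repeated by data["input_ids"].size(1) instead of a fixed self.num_generations, the same post-processing path handles both single-sample evaluation and multi-sample rollouts.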

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 22 additions & 5 deletions

@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional

 import ray
@@ -34,7 +35,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,8 +51,11 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
-
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
@@ -60,12 +64,15 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -74,7 +81,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +90,14 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -108,9 +123,11 @@ def launch_distributed(
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
-            project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
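launch_distributed now builds a single run_name from the decoding settings and a uuid-based wandb_group_name, and hands both to every producer as well as to the consumer. The point of the shared group is that each Ray actor opens its own W&B run, and the common group collates those runs as one experiment in the W&B UI. A minimal standalone sketch of that grouping behaviour; the project and run names below are made up for the example:

import uuid
import wandb

# One group id per launch; every process that joins the experiment reuses it.
group = str(uuid.uuid4())

# Consumer/trainer process.
trainer_run = wandb.init(project="grpo-demo", name="consumer", group=group)
trainer_run.log({"train/step": 0})
trainer_run.finish()

# Producer/eval process (a separate process in reality, sequential here for brevity).
producer_run = wandb.init(project="grpo-demo", name="producer-0-eval", group=group)
producer_run.log({"eval/step": 0})
producer_run.finish()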
