
Commit 3766338

fix metric calculation
1 parent: 116621d

4 files changed: +147 lines added, −48 lines removed

applications/ColossalChat/coati/distributed/consumer.py (73 additions, 12 deletions)

@@ -120,24 +120,85 @@ def loop(self) -> None:
             raw_batch = unbind_batch(
                 ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
             )
-            processed_batch = [
-                self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
-            ]
-            filtered_batch = [t for t in processed_batch if t is not None]
+            recv_effective_count = 0
+            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+            # we need to calculate the metrics before filtering here for logging
+            for group in raw_batch:
+                group_with_reward = self.calculate_group_reward(group)
+                group_reward_mean = group_with_reward["reward"].mean().cpu().item()
+                group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
+                group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
+                group_response_len = (
+                    (
+                        group_with_reward["response_idx"][:, 1]
+                        - group_with_reward["response_idx"][:, 0]
+                        + 1
+                    )
+                    .type(torch.float32)
+                    .mean()
+                    .cpu()
+                    .item()
+                )
+                filtered_group = self.prompt_level_filtering(group_with_reward)
+                recv_effective_count += 1 if filtered_group is not None else 0
+                self.buffer.append(
+                    [
+                        filtered_group,
+                        group_reward_mean,
+                        group_format_acc_mean,
+                        group_ans_acc_mean,
+                        group_response_len,
+                    ]
+                )
             if self.filter_range is not None:
                 print(
-                    f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
                 )
+            # mapping the effective group to the raw group for indexing
+            effective_group_to_raw_group_mapping = {}
+            for buffer_idx in range(len(self.buffer)):
+                if self.buffer[buffer_idx][0] is not None:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                        buffer_idx
+                    )
 
-            self.buffer.extend(filtered_batch)
-            while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                batches = self.buffer[
-                    self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                # on each dp_rank, we use minibatch_size effective samples to form a batch
+                batches = [
+                    self.buffer[effective_group_to_raw_group_mapping[i]]
+                    for i in range(
+                        self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                    )
                 ]
-                batch = bind_batch(batches)
+                # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                raw_mini_batches = self.buffer[
+                    : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                ] # include the last effective sample
+                raw_mini_batches_metric_dict = {
+                    "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                    "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                    "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                    "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                }
+                batch = bind_batch([t[0] for t in batches])
                 batch = post_recv(batch)
-                loss = self.step(i, pbar, **batch)
-                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                self.buffer = self.buffer[
+                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                ]
+                # recalculate the effective group to raw group mapping
+                effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                effective_group_to_raw_group_mapping = {}
+                for buffer_idx in range(len(self.buffer)):
+                    if self.buffer[buffer_idx][0] is not None:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                            buffer_idx
+                        )
+                assert (
+                    len(effective_group_to_raw_group_mapping)
+                    == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                )
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                 i += 1
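
Note on the change above: per-group metrics (reward, format accuracy, answer accuracy, response length) are now computed for every received group before prompt-level filtering, and each buffer entry keeps the possibly-None filtered group together with those means. Training mini-batches are assembled only from effective (non-None) entries through an index mapping, while the metric lists passed to step() cover every raw entry up to the last one consumed. A minimal, self-contained sketch of that bookkeeping, using plain numbers instead of tensor dicts (the entry layout and slicing mirror the diff; everything else is illustrative):

# A minimal sketch (not the real tensor-dict code) of the buffer bookkeeping above.
# Each buffer entry mirrors the diff: [filtered_group_or_None, reward_mean, response_len]
# (format_acc and ans_acc omitted here for brevity).

dp_size, dp_rank, minibatch_size = 2, 0, 2

def build_mapping(buf):
    # map the i-th effective (non-filtered) entry to its raw buffer index
    mapping = {}
    for buffer_idx, entry in enumerate(buf):
        if entry[0] is not None:
            mapping[len(mapping)] = buffer_idx
    return mapping

# 6 raw groups received; 2 of them were dropped by prompt-level filtering (entry[0] is None)
buffer = [
    ["g0", 0.5, 120.0],
    [None, 0.0, 80.0],
    ["g2", 0.7, 95.0],
    ["g3", 0.4, 110.0],
    [None, 0.1, 60.0],
    ["g5", 0.9, 130.0],
]
mapping = build_mapping(buffer)  # {0: 0, 1: 2, 2: 3, 3: 5}

while len(mapping) >= dp_size * minibatch_size:
    # this rank trains only on its slice of effective groups ...
    batches = [
        buffer[mapping[i]]
        for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)
    ]
    # ... but metrics are averaged over every raw group up to the last consumed one,
    # so filtered-out groups still count toward the logged reward / response length
    last = mapping[dp_size * minibatch_size - 1]
    raw_mini_batches = buffer[: last + 1]
    raw_reward_mean = sum(e[1] for e in raw_mini_batches) / len(raw_mini_batches)  # 0.4333...
    # drop everything consumed so far and rebuild the mapping, as the diff does
    buffer = buffer[last + 1 :]
    mapping = build_mapping(buffer)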

applications/ColossalChat/coati/distributed/grpo_consumer.py (29 additions, 21 deletions)

@@ -72,12 +72,12 @@ def __init__(
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
-        self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.raw_train_batch_reward = []
+        self.raw_train_batch_format_acc = []
+        self.raw_train_batch_ans_acc = []
+        self.raw_train_batch_response_len = []
         self.accum_count = 0
         self.generate_config = generate_config
         self.grpo_config = grpo_config
@@ -186,7 +186,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
+        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
+        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
+        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
+        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
@@ -430,11 +434,7 @@ def _criterion(outputs, inputs):
             self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
             if self.policy_loss_fn.beta > 0:
                 self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-            self.accum_reward.add_(reward.data)
-            self.accum_format_acc.add_(format_acc.data)
-            self.accum_ans_acc.add_(ans_acc.data)
             self.accum_advantages.add_(advantages.data)
-            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -452,21 +452,33 @@ def _criterion(outputs, inputs):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
+                raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
+                raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
+                    self.raw_train_batch_format_acc
+                )
+                raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
+                raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
+                    self.raw_train_batch_response_len
+                )
+                self.raw_train_batch_reward = []
+                self.raw_train_batch_format_acc = []
+                self.raw_train_batch_ans_acc = []
+                self.raw_train_batch_response_len = []
                 to_log_msg = [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
-                    f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
-                    f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
+                    f"Reward: {raw_batch_reward_mean:.4f}",
+                    f"format Reward: {raw_batch_format_acc_mean:.4f}",
+                    f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                    f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
-                    "metrics/reward": self.accum_reward.item() / self.accum_count,
-                    "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                    "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                    "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                    "metrics/reward": raw_batch_reward_mean,
+                    "metrics/format_acc": raw_batch_format_acc_mean,
+                    "metrics/ans_acc": raw_batch_ans_acc_mean,
+                    "metrics/response_length": raw_batch_response_len_mean,
                     "train/loss": self.accum_loss.item() / self.accum_count,
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
@@ -478,12 +490,8 @@ def _criterion(outputs, inputs):
                 if self.wandb_run is not None:
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
-                self.accum_reward.zero_()
-                self.accum_ans_acc.zero_()
-                self.accum_format_acc.zero_()
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
-                self.accum_response_length.zero_()
                 self.accum_count = 0
                 return loss_scalar
             else:
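
Note on the change above: the old accum_reward / accum_format_acc / accum_ans_acc / accum_response_length tensors only ever saw the filtered training batches, which is what skewed the logged metrics; they are replaced by plain Python lists of per-raw-group means that step() receives via the raw_train_mini_batch_* kwargs and that are averaged (then cleared) only when a log line is emitted. A small sketch of that path, with strings standing in for tensors and made-up reward values:

# Hedged sketch of the new metric path: per-group means arrive through the
# raw_train_mini_batch_* kwargs and are averaged only when a log line is emitted.

raw_train_batch_reward = []  # replaces the old self.accum_reward tensor

def step(**kwargs):
    # training tensors exclude the metric lists, exactly as in the diff
    data = {k: v for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
    raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
    return data

step(action_mask="<tensor>", raw_train_mini_batch_reward=[0.5, 0.0, 0.7])
step(action_mask="<tensor>", raw_train_mini_batch_reward=[0.4, 0.1, 0.9])

# at logging time: mean over every raw group seen since the last log, then reset
raw_batch_reward_mean = sum(raw_train_batch_reward) / len(raw_train_batch_reward)
print(f"Reward: {raw_batch_reward_mean:.4f}")  # Reward: 0.4333
raw_train_batch_reward.clear()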

applications/ColossalChat/coati/distributed/launch.py (6 additions, 1 deletion)

@@ -1,4 +1,5 @@
 import copy
+import os
 import uuid
 from typing import Any, Dict, Optional
 
@@ -56,7 +57,7 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
-    rollout_log_file: str = "./rollout_log.jsonl",
+    rollout_save_dir: str = "./rollout",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -74,6 +75,10 @@ def launch_distributed(
 
     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
+    rollout_log_file = os.path.join(
+        rollout_save_dir,
+        f"{project_name}_run_{wandb_group_name}.jsonl",
+    )
 
     procs = []
     for i in range(num_producers):
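
Note on the change above: instead of a fixed rollout_log_file argument, launch_distributed() now takes rollout_save_dir and derives a per-run file name from the project name and the wandb group UUID, so separate launches no longer append to one shared .jsonl file. A small sketch of the resulting path (the project name is hypothetical; the diff does not show the directory being created, so rollout_save_dir is presumably expected to exist or to be created by the rollout writer):

# Sketch of the rollout log path now built inside launch_distributed().
# The project name below is hypothetical; rollout_save_dir and the f-string match the diff.
import os
import uuid

rollout_save_dir = "./rollout"
project_name = "GRPO-Math"
wandb_group_name = str(uuid.uuid4())

rollout_log_file = os.path.join(
    rollout_save_dir,
    f"{project_name}_run_{wandb_group_name}.jsonl",
)
print(rollout_log_file)  # e.g. ./rollout/GRPO-Math_run_0b1c2d3e-....jsonl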

applications/ColossalChat/rl_example.py (39 additions, 14 deletions)

@@ -121,6 +121,34 @@
     parser.add_argument(
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-ptp",
+        "--produce-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:
@@ -178,7 +206,7 @@
                 enforce_eager=True,
                 enable_chunked_prefill=True,
                 max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-                tensor_parallel_size=1,
+                tensor_parallel_size=args.produce_tensor_parallel_size,
             )
         )
         generate_config.update(
@@ -228,7 +256,7 @@
 
     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
         num_consumer_procs=args.num_trainers,
         num_episodes=args.num_episodes,
        inference_batch_size=args.inference_batch_size,
@@ -247,17 +275,14 @@
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "zero_stage": 2,
-        }, # for zero
-        # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ), # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # }, # for pp, tp
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // args.pipeline_parallel_size
+            ), # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": args.zero_stage,
+            "max_norm": 1.0,
+        }, # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
@@ -273,5 +298,5 @@
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
         log_rollout_interval=20,
-        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
+        rollout_save_dir=args.rollout_save_dir,
     )
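
Note on the change above: rl_example.py replaces the hard-coded plugin_config (ZeRO stage 2, with the tp/pp variant commented out) with values taken from the new -tp, -pp, and -zero flags, while -ptp is forwarded both to the inference engine's tensor_parallel_size and as the fallback for num_proc_per_producer. The help strings of -tp, -pp, and -zero still say "for the inference backend" even though they configure the trainer's plugin. A small, self-contained sketch of the flag-to-config mapping (only the four new flags come from the diff; the --train-microbatch-size spelling is an assumption):

# Self-contained sketch of how the new flags feed the trainer plugin_config.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-tp", "--tensor-parallel-size", type=int, default=1)
parser.add_argument("-pp", "--pipeline-parallel-size", type=int, default=1)
parser.add_argument("-zero", "--zero-stage", type=int, default=0)
parser.add_argument("-ptp", "--produce-tensor-parallel-size", type=int, default=1)
parser.add_argument("--train-microbatch-size", type=int, default=8)  # assumed flag spelling

args = parser.parse_args(["-tp", "2", "-pp", "2", "-zero", "1"])

plugin_config = {
    "tp_size": args.tensor_parallel_size,
    "pp_size": args.pipeline_parallel_size,
    # microbatch size should be train_microbatch_size // pp_size, but never below 1
    "microbatch_size": max(1, args.train_microbatch_size // args.pipeline_parallel_size),
    "zero_stage": args.zero_stage,
    "max_norm": 1.0,
}
print(plugin_config)
# {'tp_size': 2, 'pp_size': 2, 'microbatch_size': 4, 'zero_stage': 1, 'max_norm': 1.0}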
