
Commit 094f119

merge

2 parents 50070c1 + aca5476

4 files changed: +121 -91 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 38 additions & 31 deletions
@@ -106,9 +106,14 @@ def loop(self) -> None:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -126,15 +131,15 @@ def loop(self) -> None:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -148,29 +153,31 @@ def loop(self) -> None:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                            )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
-                                )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
+                        if allow_sync_model:
+                            if self.pp_size > 1:
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                                )
+                            else:
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                            torch.cuda.empty_cache()
+                            state_dict = self.state_dict()
+                            if self.pp_size > 1:
+                                if self.tp_rank == 0 and self.dp_rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict,
+                                        src=self.num_producers,
+                                        device=self.device,
+                                        group_name=f"sync_model_{self.pp_rank}",
+                                    )
+                            else:
+                                if self.rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                    )
+                            del state_dict
+                            torch.cuda.empty_cache()
+                            allow_sync_model = False


 @ray.remote
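
For context, a minimal sketch (standalone toy code, not the repository's; sizes and prompt names are made up) of the new buffer handling above: step() now returns excessive_prompts_idx, a list of global buffer indices (each DP rank's local index is offset by dp_rank * minibatch_size and all-gathered in grpo_consumer.py), so the consumer moves those prompts to the front of the buffer and drops the consumed dp_size * minibatch_size slice; the weight broadcast to the producers is additionally gated on allow_sync_model, i.e. it only happens after at least one optimizer step produced a loss.

dp_size, minibatch_size = 2, 4
buffer = [f"prompt_{i}" for i in range(10)]    # toy rollout buffer on this consumer
excessive_prompts_idx = [6, 7]                 # global indices to keep for the next iteration

if excessive_prompts_idx is not None:
    excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
    buffer = excessive_prompts + buffer[dp_size * minibatch_size :]
else:
    buffer = buffer[dp_size * minibatch_size :]

print(buffer)  # ['prompt_6', 'prompt_7', 'prompt_8', 'prompt_9']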

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 52 additions & 55 deletions
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

@@ -10,7 +9,7 @@
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -43,13 +42,6 @@ def __init__(
         wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -91,6 +83,7 @@ def __init__(
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0
         self.project_name = project_name
         self.run_name = run_name
@@ -219,70 +212,66 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
+
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
-        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
-        self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()

         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )

@@ -381,7 +370,7 @@ def _criterion(outputs, inputs):
                         kl.append(appox_kl.mean())
                     else:
                         per_token_kl = 0.0
-                        kl.append(0.0)
+                        kl.append(torch.tensor(0.0))

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
@@ -485,6 +474,7 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -501,6 +491,7 @@ def _criterion(outputs, inputs):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -526,9 +517,15 @@ def _criterion(outputs, inputs):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx

     def state_dict(self):
         self.policy_model._force_wait_all_gather()
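
To illustrate the prompt-level accounting added above, here is a small self-contained sketch (hypothetical shapes and surplus, not the repository code): a prompt counts as effective if any of its num_generations samples survives loss_mask, the surplus beyond the batch_size * dp_size budget is taken from the tail of the effective prompts, and the selected local indices are later shifted by dp_rank * minibatch_size and all-gathered so every consumer can find them in its buffer.

import torch

minibatch_size, num_generations, dp_rank = 4, 2, 1   # toy values
loss_mask = torch.tensor([1, 0, 0, 0, 1, 1, 0, 1], dtype=torch.bool)  # per-sample mask

prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)        # tensor([True, False, True, True])
effective_prompts = int(effective_prompts_mask.sum())        # 3 effective prompts on this rank

excessive_prompts_per_rank = 1                                # assumed surplus, for illustration
true_indices = torch.nonzero(effective_prompts_mask).squeeze()      # tensor([0, 2, 3])
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]  # tensor([3])
effective_prompts_mask[excessive_prompts_idx] = False         # prompt 3 goes back to the buffer

# Translate local minibatch indices into global buffer indices before the all-gather.
global_idx = [int(idx) + dp_rank * minibatch_size for idx in excessive_prompts_idx]
print(effective_prompts, global_idx)  # 3 [7]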

applications/ColossalChat/coati/trainer/utils.py

Lines changed: 26 additions & 0 deletions
@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
+    """
+    Gathers tensors from all processes and concatenates them along the first dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor to be gathered.
+
+    Returns:
+        torch.Tensor: The gathered tensor.
+    """
+    # Gather tensors across DP group
+    if plugin is not None:
+        all_tensor_lists = [None] * plugin.dp_size
+        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    else:
+        all_tensor_lists = [None] * dist.get_world_size()
+        dist.all_gather_object(all_tensor_lists, local_tensor_list)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    return gathered_tensor_list
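
A hedged usage sketch for the new helper (the backend, script layout, and values are illustrative; launch with something like `torchrun --nproc_per_node=2 gather_demo.py`, where the file name is hypothetical). With plugin=None it falls back to the default process group, and since it relies on dist.all_gather_object it accepts any picklable list and returns every rank's entries concatenated in rank order.

import torch
import torch.distributed as dist

from coati.trainer.utils import all_gather_tensors

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU-friendly backend for this sketch
    rank = dist.get_rank()
    # rank 0 contributes [tensor(0)]; rank 1 contributes [tensor(10), tensor(11)]
    local_idx = [torch.tensor(rank * 10 + i) for i in range(rank + 1)]
    gathered = all_gather_tensors(local_idx, plugin=None)
    print(rank, gathered)  # every rank sees [tensor(0), tensor(10), tensor(11)]
    dist.destroy_process_group()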

applications/ColossalChat/rl_example.py

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
-    parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl")
+    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
         "--eval-dataset",
@@ -30,7 +30,7 @@
         "-ibs",
         "--inference-batch-size",
         type=int,
-        default=None,
+        default=64,
         help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
@@ -51,7 +51,7 @@
         "-tMbs",
         "--train-minibatch-size",
         type=int,
-        default=None,
+        default=8,
         help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
@@ -68,7 +68,7 @@
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
     )

     # Sampling parameters
@@ -238,7 +238,7 @@
             "zero_stage": 2,
         }, # for zero
         # plugin_config={
-        # "tp_size": 1,
+        # "tp_size": 2,
         # "pp_size": 2,
         # "microbatch_size": max(
         # 1, args.train_microbatch_size // 2
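
To make the new defaults concrete, a quick arithmetic check of the constraints stated in the help strings above; tbs, tmbs, g, and dp_size below are assumed values for illustration only, not defaults visible in this diff.

ibs, tMbs = 64, 8                      # new argparse defaults above
tbs, tmbs, g, dp_size = 32, 8, 8, 4    # assumed values, for illustration only

assert ibs % tbs == 0                  # "It should be divisible by tbs"
assert tMbs * g >= tmbs                # "Satisfy tMbs * g >= tmbs"

print("inference weights synced every", ibs // tbs, "training steps")  # 2
print("samples generated before forwarding:", tMbs * g * dp_size)      # 256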
