
Commit aca5476

Tong Li (TongLi3701) and a co-author authored
[feat] Support prompt level dynamic (#6300)
* adjust to dynamic prompt bs
* remove debug
* update pad seq (#6303)
* adjust to dynamic prompt bs
* remove debug
* fix dp issue
* fix
* fix default settings

Co-authored-by: Tong Li <[email protected]>
1 parent b920af4 commit aca5476

File tree: 4 files changed (+121 / -91 lines)

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 38 additions & 31 deletions
@@ -107,9 +107,14 @@ def loop(self) -> None:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -127,15 +132,15 @@ def loop(self) -> None:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -149,29 +154,31 @@ def loop(self) -> None:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                            )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
-                                )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                        if allow_sync_model:
+                            if self.pp_size > 1:
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                                 )
-                        del state_dict
-                        torch.cuda.empty_cache()
+                            else:
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                            torch.cuda.empty_cache()
+                            state_dict = self.state_dict()
+                            if self.pp_size > 1:
+                                if self.tp_rank == 0 and self.dp_rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict,
+                                        src=self.num_producers,
+                                        device=self.device,
+                                        group_name=f"sync_model_{self.pp_rank}",
+                                    )
+                            else:
+                                if self.rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                    )
+                            del state_dict
+                            torch.cuda.empty_cache()
+                            allow_sync_model = False


 @ray.remote
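
For readers skimming the hunks above: the consumer no longer trims its buffer by a per-rank count of excessive prompts; step() now returns the exact buffer indices it chose not to consume, and those entries are re-queued ahead of the untouched tail. A minimal sketch of that bookkeeping, with a toy buffer and made-up dp_size/minibatch_size values (only the slicing mirrors the diff):

# Toy illustration of the new buffer handling in Consumer.loop (values invented).
dp_size, minibatch_size = 2, 4
buffer = list(range(10))              # 10 queued prompt entries
excessive_prompts_idx = [6, 7]        # pretend self.step(...) returned these indices

if excessive_prompts_idx is not None:
    # carry the unconsumed prompts over to the front of the buffer
    excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
    buffer = excessive_prompts + buffer[dp_size * minibatch_size:]
else:
    buffer = buffer[dp_size * minibatch_size:]

print(buffer)                         # [6, 7, 8, 9]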

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 52 additions & 55 deletions
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

@@ -10,7 +9,7 @@
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -42,13 +41,6 @@ def __init__(
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-            minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ def __init__(
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0

         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
+
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
-        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
-        self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()

         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )

@@ -375,7 +364,7 @@ def _criterion(outputs, inputs):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.tensor(0.0))

                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
@@ -479,6 +468,7 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ def _criterion(outputs, inputs):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ def _criterion(outputs, inputs):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx

     def state_dict(self):
         self.policy_model._force_wait_all_gather()
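
To make the prompt-level accounting in the step() hunk easier to follow, here is a small standalone sketch of the dynamic-batching check. The tensors and the batch_size/dp_size/num_generations values are invented for illustration, and the loss-mask update from the committed code is omitted:

import torch

# Toy re-creation of the prompt-level dynamic batching introduced above.
minibatch_size, num_generations = 4, 2
batch_size, dp_size = 6, 1            # update once batch_size * dp_size effective prompts exist

# per-sample filter result for one minibatch (True = sample kept by reward/length filters)
loss_mask = torch.tensor([1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.bool)
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)   # a prompt counts if any rollout survives

effective_prompt_count = 4            # pretend 4 prompts were collected in earlier minibatches
effective_prompt_count += int(effective_prompts_mask.sum())

need_update = effective_prompt_count >= batch_size * dp_size
excessive_prompts = effective_prompt_count - batch_size * dp_size

excessive_prompts_idx = None
if excessive_prompts > 0:
    per_rank = excessive_prompts // dp_size
    if per_rank != 0:
        # mark the last `per_rank` effective prompts as excessive; their indices are
        # what the consumer later re-queues into its buffer
        true_indices = torch.nonzero(effective_prompts_mask).squeeze(-1)
        excessive_prompts_idx = true_indices[-per_rank:]
        effective_prompts_mask[excessive_prompts_idx] = False

print(need_update, excessive_prompts_idx)   # True tensor([3])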

applications/ColossalChat/coati/trainer/utils.py

Lines changed: 26 additions & 0 deletions
@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
+    """
+    Gathers tensors from all processes and concatenates them along the first dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor to be gathered.
+
+    Returns:
+        torch.Tensor: The gathered tensor.
+    """
+    # Gather tensors across DP group
+    if plugin is not None:
+        all_tensor_lists = [None] * plugin.dp_size
+        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    else:
+        all_tensor_lists = [None] * dist.get_world_size()
+        dist.all_gather_object(all_tensor_lists, local_tensor_list)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    return gathered_tensor_list
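
A quick way to see what the new helper does: dist.all_gather_object collects an arbitrary picklable object (here, a list of index tensors) from every rank, and the helper flattens the per-rank lists into one list. The snippet below is a self-contained sketch that reimplements only the non-plugin branch and runs as a single-process gloo group, rather than importing coati:

import os
import torch
import torch.distributed as dist


def all_gather_tensors(local_tensor_list, plugin=None):
    # same flattening logic as the committed helper, plugin branch omitted
    all_tensor_lists = [None] * dist.get_world_size()
    dist.all_gather_object(all_tensor_lists, local_tensor_list)
    gathered = []
    for tensors in all_tensor_lists:
        gathered.extend(tensors)
    return gathered


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # each rank contributes its (already rank-offset) excessive prompt indices
    local = [torch.tensor(3), torch.tensor(7)]
    print(all_gather_tensors(local))   # world_size=1 -> [tensor(3), tensor(7)]
    dist.destroy_process_group()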

applications/ColossalChat/rl_example.py

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
 parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")

 # Distributed training parameters
@@ -20,7 +20,7 @@
     "-ibs",
     "--inference-batch-size",
     type=int,
-    default=None,
+    default=64,
     help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
 )
 parser.add_argument(
@@ -41,7 +41,7 @@
     "-tMbs",
     "--train-minibatch-size",
     type=int,
-    default=None,
+    default=8,
     help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
 )
 parser.add_argument(
@@ -58,7 +58,7 @@
     "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
 )
 parser.add_argument(
-    "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+    "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
 )

 # Sampling parameters
@@ -223,7 +223,7 @@
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #     "tp_size": 1,
+        #     "tp_size": 2,
         #     "pp_size": 2,
         #     "microbatch_size": max(
         #         1, args.train_microbatch_size // 2
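
A rough sanity check of the new defaults (-ibs 64, -tMbs 8) against the constraint quoted in the -tMbs help text. num_generations and dp_size are example values, and the assumption that each prompt is rolled out g times per inference step is mine, not stated in this diff:

# Back-of-the-envelope check of the sample bookkeeping implied by the help text.
inference_batch_size = 64    # new -ibs default: prompts generated per inference step
train_minibatch_size = 8     # new -tMbs default: unique prompts per training batch per dp group
num_generations = 8          # g, example value
dp_size = 4                  # example value

# samples the backend must produce before a policy forward pass (per the -tMbs help text)
samples_needed = train_minibatch_size * num_generations * dp_size
samples_per_inference_step = inference_batch_size * num_generations
print(samples_needed, samples_per_inference_step)   # 256 512 -> one inference step suffices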
