
Commit 11a5854

remove redundant code and fix bugs
1 parent a528921 · commit 11a5854

File tree

applications/ColossalChat/coati/distributed/consumer.py
applications/ColossalChat/coati/distributed/grpo_consumer.py
applications/ColossalChat/coati/distributed/launch.py
applications/ColossalChat/coati/distributed/producer.py
applications/ColossalChat/rl_example.py

5 files changed: +27 −59 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 9 additions & 14 deletions

```diff
@@ -121,14 +121,14 @@ def loop(self) -> None:
                 raw_batch = unbind_batch(
                     ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                 )
-                filtered_batch = [
-                    t
-                    for t in [
-                        self.prompt_level_filtering(self.calculate_group_reward(group))
-                        for group in raw_batch
-                    ]
-                    if t is not None
+                processed_batch = [
+                    self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
                 ]
+                filtered_batch = [t for t in processed_batch if t is not None]
+                if self.filter_range is not None:
+                    print(
+                        f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                    )

                 self.buffer.extend(filtered_batch)
                 while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,13 +137,8 @@ def loop(self) -> None:
                     ]
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
-                    loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                    if excessive_prompts_idx is not None:
-                        excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                        self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                    else:
-                        self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                    loss = self.step(i, pbar, **batch)
+                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                     if loss is not None:
                         allow_sync_model = True
                     pbar.set_postfix({"loss": loss})
```
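
The first hunk splits a nested list comprehension into two named steps so the pre-filter count is available for the new log line; the second hunk drops the excessive-prompt requeueing now that `step()` returns only the loss. A minimal standalone sketch of the new filtering pattern, with stand-in scoring and filtering functions in place of `calculate_group_reward` and `prompt_level_filtering`:

```python
# Sketch only: score_group and prompt_level_filter are hypothetical stand-ins
# for calculate_group_reward / prompt_level_filtering in the real consumer.
from typing import Optional


def score_group(group: dict) -> dict:
    group["reward"] = sum(group["scores"]) / len(group["scores"])
    return group


def prompt_level_filter(group: dict, lo: float = 0.01, hi: float = 0.99) -> Optional[dict]:
    # Drop groups whose samples all failed or all got full marks.
    return group if lo <= group["reward"] <= hi else None


raw_batch = [{"scores": [1, 1, 1]}, {"scores": [0, 1, 0]}, {"scores": [0, 0, 0]}]
# Step 1: score and filter each group (None marks a dropped group).
processed_batch = [prompt_level_filter(score_group(g)) for g in raw_batch]
# Step 2: keep the survivors; the intermediate list lets us log both counts.
filtered_batch = [g for g in processed_batch if g is not None]
print(f"Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")  # 3 -> 1
```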

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 7 additions & 40 deletions

```diff
@@ -9,7 +9,7 @@
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
+
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

@@ -214,37 +211,14 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             loss_mask,
             action_mask[:, -1] == False,
         )
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
-
-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -460,9 +434,7 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            self.total_sample_count = all_reduce_sum(
-                torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-            ).item()
+            # No need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
@@ -507,14 +479,9 @@ def _criterion(outputs, inputs):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None

     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
```

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

```
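A worked example of the schedule arithmetic around the changed line, using illustrative numbers (not from the commit):

```python
# Assumed values for illustration: 4 producers, inference batch 32,
# microbatch 8, and a 12,800-sample training set.
num_producers = 4
inference_batch_size = 32
inference_microbatch_size = 8
num_samples = 12_800

global_inference_batch_size = inference_batch_size * num_producers        # 128
num_update_per_episode = num_samples // global_inference_batch_size       # 100 updates per episode
num_recv_per_update = inference_batch_size // inference_microbatch_size   # 4 receives per update
print(global_inference_batch_size, num_update_per_episode, num_recv_per_update)  # 128 100 4
```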

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -187,7 +187,7 @@ def loop(self) -> None:
         for eval_task_name in self.eval_dataloaders:
             if self.producer_idx == 0:
                 print(
-                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
                 )
             eval_results = []
             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
```

applications/ColossalChat/rl_example.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -104,7 +104,13 @@
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei consumer steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -125,8 +131,8 @@
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
```
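
The second hunk strengthens the batch-size assertion to also require divisibility, so every train batch splits into a whole number of minibatches. A small sketch of the check with illustrative numbers (`check_batch_sizes` is a hypothetical helper, not from the script):

```python
def check_batch_sizes(train_batch_size: int, train_minibatch_size: int) -> None:
    # Minibatch must fit inside the batch AND divide it evenly.
    assert (
        train_minibatch_size <= train_batch_size and train_batch_size % train_minibatch_size == 0
    ), "Train mini batch size must divide train batch size"


check_batch_sizes(64, 16)  # passes: 64 / 16 = 4 minibatches per batch
# check_batch_sizes(64, 24) would now fail: 24 <= 64, but 64 % 24 != 0
```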

0 commit comments
