
Commit 2a39d3a

address conversation
1 parent 4b1c515 commit 2a39d3a

2 files changed: +18, -43 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 5 additions & 6 deletions
@@ -117,14 +117,12 @@ def loop(self) -> None:
                     # receive data from producers
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                        raw_batch = unbind_batch(
-                            ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                        )
+                        raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                         recv_effective_count = 0
                         # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                         # we need to calculate the metrics before filtering here for logging
-                        for group in raw_batch:
-                            group_with_reward = self.calculate_group_reward(group)
+                        raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
+                        for group_with_reward in raw_batch_with_reward:
                             group_reward_mean = group_with_reward["reward"].mean().cpu().item()
                             group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
                             group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
@@ -139,7 +137,8 @@ def loop(self) -> None:
                                 .cpu()
                                 .item()
                             )
-                            filtered_group = self.prompt_level_filtering(group_with_reward)
+                            if self.grpo_config.get("dynamic_batching", True):
+                                filtered_group = self.prompt_level_filtering(group_with_reward)
                             recv_effective_count += 1 if filtered_group is not None else 0
                             self.buffer.append(
                                 [
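
Note on the hunk above: rewards are now computed once on the whole broadcast batch and only then split into per-prompt groups. This assumes unbind_batch turns a dict of stacked tensors into one dict per group; a minimal, self-contained sketch of that assumed behavior (stand-in helper and illustrative shapes, not the repository implementation):

import torch

def unbind_batch_sketch(batch: dict) -> list:
    """Split {key: Tensor[B, ...]} into a list of B dicts {key: Tensor[...]}."""
    size = next(iter(batch.values())).size(0)
    return [{k: v[i] for k, v in batch.items()} for i in range(size)]

# toy usage: 2 prompt groups, 4 generations each, one scalar metric per generation
batch = {"reward": torch.rand(2, 4, 1), "ans_acc": torch.rand(2, 4, 1)}
groups = unbind_batch_sketch(batch)
assert len(groups) == 2 and groups[0]["reward"].shape == (4, 1)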

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 37 deletions
@@ -218,30 +218,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask)
-                    # Make sure the indices are not empty.
-                    if true_indices.numel() > 0:
-                        true_indices = true_indices.squeeze(-1)
-                        if excessive_prompts_per_rank <= len(true_indices):
-                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                        else:
-                            excessive_prompts_idx = true_indices
-                        effective_prompts_mask[excessive_prompts_idx] = False
-
-                        for mask_idx in range(len(effective_prompts_mask)):
-                            if effective_prompts_mask[mask_idx] == False:
-                                # Update loss mask.
-                                loss_mask[mask_idx] = False
-            else:
-                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -510,7 +486,7 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.

@@ -529,20 +505,20 @@ def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
-
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
