
Commit 55eee12

move prompt-level-filtering to buffer side
1 parent 957e3a5 commit 55eee12

2 files changed: +81 -24 lines changed


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 14 additions & 8 deletions
@@ -113,18 +113,24 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
+                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = unbind_batch(
+                                ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
+                            filtered_batch = [
+                                t
+                                for t in [
+                                    self.prompt_level_filtering(self.calculate_group_reward(group))
+                                    for group in raw_batch
+                                ]
+                                if t is not None
+                            ]
+
+                            self.buffer.extend(filtered_batch)
                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
                             batches = self.buffer[
                                 self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
@@ -177,7 +183,7 @@ def loop(self) -> None:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
-                        allow_sync_model = False
+                        allow_sync_model = True


 @ray.remote
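
With this change, the consumer computes group rewards and applies the prompt-level filter before anything reaches the replay buffer, so groups whose mean answer accuracy falls outside `filter_range` never occupy buffer slots. Below is a minimal standalone sketch of that filter-then-extend pattern; the toy data, the simplified `calculate_group_reward`, and the hard-coded `filter_range` are illustrative stand-ins, not the repository's actual API.

```python
from typing import Any, Dict, List, Optional

import torch

# Hypothetical band: keep prompts whose group accuracy lies strictly inside it.
filter_range = (0.05, 0.95)
buffer: List[Dict[str, Any]] = []


def calculate_group_reward(group: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for the reward model: here per-generation accuracy is already given.
    group["ans_acc"] = group["ans_acc"].view(-1, 1)
    return group


def prompt_level_filtering(group: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    # Drop the whole prompt group when its mean accuracy is out of range.
    acc = torch.mean(group["ans_acc"])
    if acc < filter_range[0] or acc > filter_range[1]:
        return None
    return group


raw_batch = [
    {"prompt_id": 0, "ans_acc": torch.tensor([1.0, 1.0, 1.0, 1.0])},  # too easy -> filtered out
    {"prompt_id": 1, "ans_acc": torch.tensor([0.0, 1.0, 0.0, 1.0])},  # informative -> kept
]
filtered_batch = [
    t
    for t in [prompt_level_filtering(calculate_group_reward(group)) for group in raw_batch]
    if t is not None
]
buffer.extend(filtered_batch)
print([g["prompt_id"] for g in buffer])  # [1]
```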

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 67 additions & 16 deletions
@@ -1,5 +1,5 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
@@ -179,7 +179,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         Format:
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
         data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
         action_mask = data["action_mask"]
@@ -188,15 +187,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
         train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        reward_group = self.reward_model(
-            data["input_ids"],
-            gt_answer=data["gt_answer"],
-            response_idx=data["response_idx"],
-        )
-
-        reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))

         # [minibatch_size, num_generations]

@@ -213,11 +206,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
         # [minibatch_size x num_of_generation]
-        loss_mask = (
-            torch.ones(action_mask.size(0), device=action_mask.device).bool()
-            if self.filter_range is None
-            else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-        )
+        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
@@ -525,6 +514,68 @@ def _criterion(outputs, inputs):
         else:
             return None, excessive_prompts_idx

+    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contain the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_group = self.reward_model(
+            rollout_group["input_ids"],
+            gt_answer=rollout_group["gt_answer"],
+            response_idx=rollout_group["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+
+        rollout_group["reward"] = reward.view((-1, 1))
+        rollout_group["format_acc"] = format_acc.view((-1, 1))
+        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout_group
+
+    def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        rollout_group: Dict[str, Any]
+            a group of samples generated by the model from the same prompt
+            contain the following keys:
+            "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "action_mask": torch.Tensor, [num_of_generation, response_length]
+            "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+            "response_idx": int, torch.Tensor, [num_of_generation, 2]
+            "gt_answer": torch.Tensor, [num_of_generation, 128]
+            "temperature": torch.Tensor, [] (scalar)
+            "reward": torch.Tensor, [num_of_generation]
+            "format_acc": torch.Tensor, [num_of_generation]
+            "ans_acc": torch.Tensor, [num_of_generation]
+        """
+        if self.filter_range is not None:
+            # filter prompt whoes accuracy is too high or too low (out of range)
+            group_ans_acc = torch.mean(rollout_group["ans_acc"])
+            if group_ans_acc < self.filter_range[0] or group_ans_acc > self.filter_range[1]:
+                # filter out the prompt
+                return None
+            else:
+                return rollout_group
+        else:
+            # no filter
+            return rollout_group
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
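
After this commit, the reward model runs once per group on the buffer side, and its per-sample outputs travel with the rollout as extra tensor fields; `step()` then only reshapes `data["reward"]`, `data["format_acc"]`, and `data["ans_acc"]` instead of calling the reward model again. A hedged sketch of that hand-off is below; `DummyRewardModel`, the tensor shapes, and the triple-per-sample output format are assumptions for illustration, modeled on the diff above rather than the repository's exact interfaces.

```python
from typing import Any, Dict, List, Tuple

import torch


class DummyRewardModel:
    """Stand-in reward model returning (reward, format_acc, ans_acc) per generation."""

    def __call__(self, input_ids: torch.Tensor, **kwargs: Any) -> List[Tuple[float, float, float]]:
        return [(1.0, 1.0, 1.0) for _ in range(input_ids.size(0))]


reward_model = DummyRewardModel()
num_of_generation, seq_len = 4, 16
rollout_group: Dict[str, Any] = {"input_ids": torch.zeros(num_of_generation, seq_len, dtype=torch.long)}

# Buffer side: attach per-generation reward columns to the rollout group.
triples = reward_model(rollout_group["input_ids"])
rollout_group["reward"] = torch.tensor([t[0] for t in triples]).view(-1, 1)
rollout_group["format_acc"] = torch.tensor([t[1] for t in triples]).view(-1, 1)
rollout_group["ans_acc"] = torch.tensor([t[2] for t in triples]).view(-1, 1)

# Trainer side: flatten the precomputed columns instead of re-running the reward model.
reward = rollout_group["reward"].view(-1)
ans_acc = rollout_group["ans_acc"].view(-1)
print(reward.shape, ans_acc.mean().item())  # torch.Size([4]) 1.0
```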
