Commit ceb7065

Merge pull request #6312 from hpcaitech/grpo-latest-dev
[feat] Move prompt-level-filtering to buffer side
2 parents: c8b368c + 96faf54

File tree

5 files changed: +251 -143 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 93 additions & 17 deletions
@@ -117,26 +117,102 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
-                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                            batches = self.buffer[
-                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+                            # we need to calculate the metrics before filtering here for logging
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
+                            # [batch_size, num_generations] -> [batch_size]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                            response_len = (
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # filter the group based on the reward and accuracy
+                                group_ans_acc_mean = ans_acc.mean(dim=1)
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                                )
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                                self.buffer.append(
+                                    [
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
+                                        reward[group_idx],
+                                        format_acc[group_idx],
+                                        ans_acc[group_idx],
+                                        response_len[group_idx],
+                                    ]
+                                )
+                            if effective_group_mask is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                )
+                            # mapping the effective group to the raw group for indexing
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            print(
+                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                            )
+
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            # on each dp_rank, we use minibatch_size effective samples to form a batch
+                            batches = [
+                                self.buffer[effective_group_to_raw_group_mapping[i]]
+                                for i in range(
+                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                                )
                             ]
-                            batch = bind_batch(batches)
+                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                            raw_mini_batches = self.buffer[
+                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                            ]  # include the last effective sample
+                            raw_mini_batches_metric_dict = {
+                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                            }
+                            batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
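
The first block of additions reshapes the flat [batch_size * num_generations, ...] tensors back into per-prompt groups before any filtering, so logging still sees every completion. A minimal standalone sketch of that reshape and the response-length computation, using toy shapes (batch_size=2, num_generations=4; all values made up):

import torch

batch_size, num_generations = 2, 4  # toy sizes; the real values come from the trainer config

# rewards arrive flat, one row per completion: [batch_size * num_generations, 1]
flat_reward = torch.arange(batch_size * num_generations, dtype=torch.float32).view(-1, 1)

# regroup by prompt, as in the diff: [batch_size, num_generations, 1]
grouped = flat_reward.view(-1, num_generations, flat_reward.size(-1))

# drop the trailing feature dim: [batch_size, num_generations]
reward = grouped[:, :, 0]
print(reward.shape)  # torch.Size([2, 4])

# response_idx stores (start, end) token positions for each completion
response_idx = torch.tensor([[[5, 9], [5, 12], [5, 7], [5, 20]],
                             [[3, 8], [3, 15], [3, 3], [3, 11]]])
response_len = (response_idx[:, :, 1] - response_idx[:, :, 0] + 1).type(torch.float32)
print(response_len)  # tensor([[ 5.,  8.,  3., 16.], [ 6., 13.,  1.,  9.]])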

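The dynamic-batching branch keeps only groups whose mean answer accuracy falls strictly inside filter_range: all-correct or all-wrong groups give GRPO no useful advantage signal. A toy sketch of the mask computation (the filter_range bounds here are assumed for illustration, not taken from grpo_config):

import torch

# group-level answer accuracy for 4 prompts x 4 generations (toy values)
ans_acc = torch.tensor([[1., 1., 1., 1.],   # always solved -> no gradient signal
                        [0., 0., 0., 0.],   # never solved  -> no gradient signal
                        [1., 0., 1., 0.],
                        [1., 1., 0., 1.]])
filter_range = [0.01, 0.99]  # assumed bounds for illustration

group_ans_acc_mean = ans_acc.mean(dim=1)  # tensor([1.0000, 0.0000, 0.5000, 0.7500])
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(effective_group_mask)  # tensor([False, False,  True,  True])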
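
Because filtered-out groups stay in the buffer as None placeholders (so their metrics remain loggable), training indexes the buffer through the effective-to-raw mapping rebuilt in the hunk above. A self-contained sketch with a hypothetical six-entry buffer, where short strings stand in for the real tensor dicts:

# buffer entries are [group_or_None, reward, format_acc, ans_acc, response_len];
# only slot 0 matters for the mapping, so the other slots are omitted here
buffer = [["g0"], [None], ["g2"], ["g3"], [None], ["g5"]]

effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(buffer)):
    if buffer[buffer_idx][0] is not None:
        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
print(effective_group_to_raw_group_mapping)  # {0: 0, 1: 2, 2: 3, 3: 5}

# with dp_size=2 and minibatch_size=2, each rank slices its own window of
# effective indices and resolves them to raw buffer positions
dp_size, minibatch_size = 2, 2
for dp_rank in range(dp_size):
    raw_indices = [
        effective_group_to_raw_group_mapping[i]
        for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)
    ]
    print(dp_rank, raw_indices)  # 0 [0, 2] then 1 [3, 5]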
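
After each step the buffer is trimmed past the last consumed effective sample and the mapping is rebuilt; the new assert checks that exactly dp_size * minibatch_size effective samples were consumed. Continuing the toy buffer, extended by two entries:

buffer = [["g0"], [None], ["g2"], ["g3"], [None], ["g5"], [None], ["g7"]]
effective_group_to_raw_group_mapping = {0: 0, 1: 2, 2: 3, 3: 5, 4: 7}
dp_size, minibatch_size = 2, 2

size_before = len(effective_group_to_raw_group_mapping)                         # 5
last_raw = effective_group_to_raw_group_mapping[dp_size * minibatch_size - 1]   # raw index 5
buffer = buffer[last_raw + 1 :]                                                 # [[None], ['g7']]

# rebuild the mapping over the shrunken buffer
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(buffer)):
    if buffer[buffer_idx][0] is not None:
        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx

# exactly dp_size * minibatch_size effective samples were consumed: 1 == 5 - 4
assert len(effective_group_to_raw_group_mapping) == size_before - dp_size * minibatch_size
print(effective_group_to_raw_group_mapping)  # {0: 1}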