
Commit 0d00811

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7b921ac commit 0d00811

2 files changed: +33 -15 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 26 additions & 12 deletions
@@ -117,32 +117,44 @@ def loop(self) -> None:
 # receive data from producers
 for r in range(self.num_producers):
     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-    raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
+    raw_batch = ray_broadcast_tensor_dict(
+        None, src=0, device=self.device, group_name=f"sync_data_{r}"
+    )
     # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
     # we need to calculate the metrics before filtering here for logging
     # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-    raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k!='temperature' else v for k, v in raw_batch.items()})
-    raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k!='temperature' else v for k, v in raw_batch_with_reward.items()}
+    raw_batch_with_reward = self.calculate_reward(
+        {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+    )
+    raw_batch_with_reward = {
+        k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+        for k, v in raw_batch_with_reward.items()
+    }
     # [batch_size, num_generations] -> [batch_size]
-    reward = raw_batch_with_reward["reward"][:,:,0]
-    format_acc = raw_batch_with_reward["format_acc"][:,:,0]
-    ans_acc = raw_batch_with_reward["ans_acc"][:,:,0]
+    reward = raw_batch_with_reward["reward"][:, :, 0]
+    format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+    ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
     response_len = (
-        (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
-        .type(torch.float32)
-    )
+        raw_batch_with_reward["response_idx"][:, :, 1]
+        - raw_batch_with_reward["response_idx"][:, :, 0]
+        + 1
+    ).type(torch.float32)
     effective_group_mask = None
     if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
         # filter the group based on the reward and accuracy
         group_ans_acc_mean = ans_acc.mean(dim=1)
         effective_group_mask = torch.logical_and(
             group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
         )
-    raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
+    raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
     for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
         self.buffer.append(
             [
-                group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
+                (
+                    group_with_reward
+                    if effective_group_mask is None or effective_group_mask[group_idx]
+                    else None
+                ),
                 reward[group_idx],
                 format_acc[group_idx],
                 ans_acc[group_idx],

@@ -160,7 +172,9 @@ def loop(self) -> None:
         effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
             buffer_idx
         )
-print(f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}")
+print(
+    f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+)

 while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
     # on each dp_rank, we use minibatch_size effective samples to form a batch
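
Note for context: the filter in this hunk keeps a prompt group only when its mean answer accuracy lies strictly inside filter_range, so groups every generation gets right (or wrong) are dropped from training. A minimal, standalone sketch of that check with made-up tensor values (not code from this commit):

import torch

filter_range = (0.1, 0.9)  # assumed (low, high) accuracy bounds, mirroring self.filter_range

# ans_acc: one answer-accuracy score per generation, grouped by prompt -> [batch_size, num_generations]
ans_acc = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0],  # every generation correct -> too easy, filtered out
        [0.0, 1.0, 0.0, 1.0],  # mixed group -> kept for training
        [0.0, 0.0, 0.0, 0.0],  # every generation wrong -> too hard, filtered out
    ]
)

group_ans_acc_mean = ans_acc.mean(dim=1)  # [batch_size]
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(effective_group_mask)  # tensor([False,  True, False])

Filtered-out groups are still appended to the buffer, just with None in place of the group data, so reward, format_acc, and ans_acc can still be logged for the full raw batch, as the comments in the hunk explain.
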

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 7 additions & 3 deletions
@@ -211,10 +211,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
     loss_mask,
     action_mask[:, -1] == False,
 )
-if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False)==False:
+if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
     # filter out samples with reward outside the range
     # if dynamic batching is enabled, we filter out out of range groups before training
-    group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
+    group_ans_acc_mean = (
+        ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
+    )
     loss_mask = torch.logical_and(
         loss_mask,
         torch.logical_and(
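
Note for context: when dynamic batching is disabled, step() applies the same range filter per sample. The group's mean accuracy is broadcast back to sample granularity with repeat_interleave and then folded into loss_mask via torch.logical_and. A standalone sketch with made-up values (not code from this commit):

import torch

num_generations = 4
filter_range = (0.1, 0.9)  # assumed (low, high) bounds, mirroring self.filter_range

# ans_acc arrives flattened as [batch_size * num_generations]
ans_acc = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0])

# mean accuracy per group, repeated so each sample carries its group's mean
group_ans_acc_mean = (
    ans_acc.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=-1)
)
print(group_ans_acc_mean)  # tensor([1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5])

# samples whose group mean falls outside filter_range would be masked out of the loss
in_range = torch.logical_and(group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1])
print(in_range)  # first group (mean 1.0) masked out, second group (mean 0.5) kept
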
@@ -454,7 +456,9 @@ def _criterion(outputs, inputs):
 raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
 raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
 raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
-overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate
+overlength_samples_ratio = (
+    (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
+)  # not an exact figure, but a close estimate
 self.raw_train_batch_reward = []
 self.raw_train_batch_format_acc = []
 self.raw_train_batch_ans_acc = []
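
Note for context: overlength_samples_ratio logs the fraction of responses that filled the whole generation window and were therefore likely truncated; as the inline comment says, it is a close estimate rather than an exact figure. A standalone sketch with made-up values (not code from this commit):

import torch

max_new_tokens = 8  # stands in for action_mask.size(-1), the generation window length
raw_batch_response_len = torch.tensor([3.0, 8.0, 8.0, 5.0])

# responses that used the full window are counted as (probably) truncated
overlength_samples_ratio = (
    (raw_batch_response_len >= max_new_tokens).to(float).mean().cpu().item()
)
print(overlength_samples_ratio)  # 0.5
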
