
Commit 3bed6ae

fix bug, tested
1 parent dc3033e commit 3bed6ae

File tree

3 files changed (+7, -5 lines)


.gitignore

Lines changed: 4 additions & 0 deletions
@@ -167,3 +167,7 @@ applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
 applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 2 additions & 4 deletions
@@ -132,9 +132,7 @@ def loop(self) -> None:
         format_acc = raw_batch["format_acc"][:, :, 0]
         ans_acc = raw_batch["ans_acc"][:, :, 0]
         response_len = (
-            raw_batch["response_idx"][:, :, 1]
-            - raw_batch["response_idx"][:, :, 0]
-            + 1
+            raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
         ).type(torch.float32)
         effective_group_mask = None
         if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -160,7 +158,7 @@ def loop(self) -> None:
         )
         if effective_group_mask is not None:
             print(
-                f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
             )
         # mapping the effective group to the raw group for indexing
         effective_group_to_raw_group_mapping = {}
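
Note: the first hunk only collapses the response-length expression onto one line, and the second fixes the log message to reference raw_batch after the earlier raw_batch_with_reward variable was dropped. Below is a minimal, self-contained sketch of what that expression computes, assuming response_idx holds inclusive (start, end) token positions per sampled response; the example tensor and its values are illustrative, not taken from the repository.

import torch

# Illustrative data with shape [batch, num_generations, 2]; the last dim is
# assumed to be (start, end) token indices of each response, end inclusive.
response_idx = torch.tensor(
    [[[5, 12], [5, 9]],
     [[3, 20], [3, 7]]]
)

# Same computation as the one-line expression in consumer.py:
# length = end - start + 1, cast to float32 for downstream use.
response_len = (
    response_idx[:, :, 1] - response_idx[:, :, 0] + 1
).type(torch.float32)

print(response_len)  # tensor([[ 8.,  5.],
                     #         [18.,  5.]])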

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ def loop(self) -> None:
         reward_model_output = self.reward_model(
             outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
             gt_answer=gt_answer,
-            response_idx=outputs["response_idx"],
+            response_idx=outputs["response_idx"].view((-1, 2)),
         )
         outputs["reward"] = (
             torch.tensor([value[0] for value in reward_model_output])
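
The producer change lines up the leading dimension of response_idx with the flattened input_ids that the reward model receives. A small sketch of the shape bookkeeping, assuming outputs["input_ids"] is [batch, num_generations, seq_len] and outputs["response_idx"] is [batch, num_generations, 2]; these shapes are inferred from the call site, not confirmed by the diff.

import torch

batch, num_generations, seq_len = 2, 4, 16

# Illustrative tensors matching the assumed shapes at the call site.
input_ids = torch.zeros(batch, num_generations, seq_len, dtype=torch.long)
response_idx = torch.zeros(batch, num_generations, 2, dtype=torch.long)

# input_ids is flattened to [batch * num_generations, seq_len] before the
# reward-model call, so response_idx is flattened the same way; otherwise
# the two arguments disagree on how many responses they describe.
flat_input_ids = input_ids.view((-1, input_ids.size(-1)))  # shape [8, 16]
flat_response_idx = response_idx.view((-1, 2))             # shape [8, 2]

assert flat_input_ids.size(0) == flat_response_idx.size(0)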
