
Commit ab95624

handle empty index (#6311)

Authored by TongLi3701 (Tong Li)
Co-authored-by: Tong Li <[email protected]>

1 parent aca5476 · commit ab95624

File tree

applications/ColossalChat/coati/distributed/consumer.py
applications/ColossalChat/coati/distributed/grpo_consumer.py

2 files changed: +37 -36 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 22 additions & 26 deletions
@@ -114,7 +114,6 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -140,7 +139,6 @@ def loop(self) -> None:
                             else:
                                 self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ def loop(self) -> None:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                            allow_sync_model = False
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()


 @ray.remote
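Note (not part of the commit): a minimal sketch of the behavioural change above, under the assumption of the simplified loop structure shown here; loop() and sync_model() are hypothetical stand-ins for Consumer.loop and the ray_broadcast_tensor_dict calls. Before this commit the weight broadcast was gated on an allow_sync_model flag that was only set when a minibatch produced a loss; afterwards it runs at every non-final update step.

def loop(num_episodes: int, num_update_per_episode: int, sync_model) -> None:
    for episode in range(num_episodes):
        for step in range(num_update_per_episode):
            # ... receive rollouts from producers and run optimizer micro-steps ...
            is_last = episode == num_episodes - 1 and step == num_update_per_episode - 1
            if not is_last:
                # Previously wrapped in `if allow_sync_model:`; now unconditional.
                sync_model(episode, step)

# Every non-final step now triggers a sync: with 2 episodes of 3 steps, 5 syncs happen.
loop(2, 3, lambda e, s: print(f"sync after episode {e} step {s}"))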

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 15 additions & 10 deletions
@@ -239,17 +239,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 # TODO: customize excessive prompts calculation.
                 if excessive_prompts_per_rank != 0:
                     # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
+                    true_indices = torch.nonzero(effective_prompts_mask)
+                    # Make sure the indices are not empty.
+                    if true_indices.numel() > 0:
+                        true_indices = true_indices.squeeze()
+                        if excessive_prompts_per_rank <= len(true_indices):
+                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                        else:
+                            excessive_prompts_idx = true_indices
+                        effective_prompts_mask[excessive_prompts_idx] = False

-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+                        for mask_idx in range(len(effective_prompts_mask)):
+                            if effective_prompts_mask[mask_idx] == False:
+                                # Update loss mask.
+                                loss_mask[mask_idx] = False
+                    else:
+                        excessive_prompts_idx = torch.empty([0])
                 else:
                     # If dynamic batching is disabled, we need to use all samples for training.
                     need_update = (step_idx + 1) % self.num_microbatches == 0
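Note (not part of the commit): a standalone illustration of the corner case the added numel() check guards against. The tensor names mirror the diff, but the snippet is an assumption-based sketch, not repository code: torch.nonzero over an all-False mask returns an empty index tensor, so the masking logic is skipped and excessive_prompts_idx is still left defined for downstream code.

import torch

effective_prompts_mask = torch.zeros(4, dtype=torch.bool)   # no prompt marked effective
excessive_prompts_per_rank = 2

true_indices = torch.nonzero(effective_prompts_mask)         # shape (0, 1): empty index
if true_indices.numel() > 0:
    true_indices = true_indices.squeeze()
    if excessive_prompts_per_rank <= len(true_indices):
        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
    else:
        excessive_prompts_idx = true_indices
    effective_prompts_mask[excessive_prompts_idx] = False
else:
    # Empty index: keep the variable defined so later code can still consume it.
    excessive_prompts_idx = torch.empty([0])

print(excessive_prompts_idx.shape)                            # torch.Size([0])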
