
Commit 1644adf

Tong Li authored and YeAnbang committed
handle empty index
1 parent 957e3a5 commit 1644adf

File tree: 2 files changed, +37 −36 lines changed


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 22 additions & 26 deletions
@@ -113,7 +113,6 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -139,7 +138,6 @@ def loop(self) -> None:
                             else:
                                 self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -153,31 +151,29 @@ def loop(self) -> None:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                        allow_sync_model = False
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()


 @ray.remote
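
Note on the consumer.py change: with allow_sync_model removed, the consumer now broadcasts its state_dict to the producers at the end of every update step (skipping only the final step of the final episode), instead of only after a step that returned a loss. The sketch below is a minimal, stand-alone illustration of the fan-out rule kept by the new code; sync_model_after_step and broadcast are hypothetical names, and broadcast merely stands in for ray_broadcast_tensor_dict from the real code.

```python
# Minimal sketch (assumptions, not the ColossalChat implementation): which rank
# pushes the consumer weights to the producers after an update step.

def broadcast(state_dict, src, group_name):
    # Stub standing in for ray_broadcast_tensor_dict(...) in the real code.
    print(f"broadcast {len(state_dict)} tensors on group '{group_name}' (src={src})")

def sync_model_after_step(pp_size, pp_rank, tp_rank, dp_rank, rank, num_producers, state_dict):
    if pp_size > 1:
        # With pipeline parallelism each PP stage has its own group; only the
        # (tp_rank 0, dp_rank 0) replica of that stage sends.
        if tp_rank == 0 and dp_rank == 0:
            broadcast(state_dict, src=num_producers, group_name=f"sync_model_{pp_rank}")
    else:
        # Without PP there is a single global group and only rank 0 sends.
        if rank == 0:
            broadcast(state_dict, src=num_producers, group_name="sync_model")

# Example: 2 PP stages; the stage-0 leader pushes on "sync_model_0".
sync_model_after_step(pp_size=2, pp_rank=0, tp_rank=0, dp_rank=0, rank=0,
                      num_producers=4, state_dict={"w": 0})
```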

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 15 additions & 10 deletions
@@ -245,17 +245,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 # TODO: customize excessive prompts calculation.
                 if excessive_prompts_per_rank != 0:
                     # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
+                    true_indices = torch.nonzero(effective_prompts_mask)
+                    # Make sure the indices are not empty.
+                    if true_indices.numel() > 0:
+                        true_indices = true_indices.squeeze()
+                        if excessive_prompts_per_rank <= len(true_indices):
+                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                        else:
+                            excessive_prompts_idx = true_indices
+                        effective_prompts_mask[excessive_prompts_idx] = False

-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+                        for mask_idx in range(len(effective_prompts_mask)):
+                            if effective_prompts_mask[mask_idx] == False:
+                                # Update loss mask.
+                                loss_mask[mask_idx] = False
+                    else:
+                        excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0