Skip to content

Commit c865de3

Browse files
committed
cherry pick zero bubble RL
1 parent 2336d7f commit c865de3

File tree

4 files changed

+8
-17
lines changed

4 files changed

+8
-17
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -425,20 +425,6 @@ def _criterion(outputs, inputs):
425425
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
426426
mean_kl.append(kl)
427427
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
428-
mini_batch_entropies.append(
429-
all_reduce_mean(
430-
(
431-
(
432-
(
433-
entropy_from_logits(policy_model_logits[:, -num_action:])
434-
* action_mask_forward_micro_batch
435-
).sum(-1)
436-
)
437-
/ action_mask_forward_micro_batch.sum(-1)
438-
).detach(),
439-
self.plugin,
440-
)
441-
)
442428
else:
443429
policy_model_logits = self.policy_model(
444430
input_ids=input_ids_forward_micro_batch,

applications/ColossalChat/coati/distributed/zero_bubble/distributor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def loop(self):
6464
)
6565
self.profiler.exit(f"sync_model_consumer_pp_{i}")
6666
self.weight_version[i] += 1
67-
for i in range(self.consumer_pp_size):
68-
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
67+
if all(
68+
[signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
69+
):
70+
for i in range(self.consumer_pp_size):
6971
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
7072
# Broadcast the model state dict to all producers
7173
ray.get(
@@ -116,4 +118,4 @@ def loop(self):
116118
ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
117119

118120
def get_weight_version(self):
119-
return min(self.weight_version)
121+
return self.weight_version[0]

applications/ColossalChat/coati/distributed/zero_bubble/producer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def sync_model_thread():
244244
f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model"
245245
)
246246
)
247+
for pp_idx in range(self.consumer_pp_size):
247248
print(
248249
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
249250
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ray==2.49.2
2+
pygloo>=0.2.0  # must be built from source: https://github.com/ray-project/pygloo (commit 82ae2d72222aefcac54a8e88995735ede3abe9cf); build instructions: https://github.com/ray-project/pygloo/blob/main/README.md

0 commit comments

Comments
 (0)