Skip to content

Commit bd61918

Browse files
committed
reuse comm-group
1 parent 57a8839 commit bd61918

File tree

2 files changed: +10 additions, -8 deletions

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ def setup(self) -> None:
9494
if self.rank == 0:
9595
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
9696

97-
for i in range(self.num_producers):
98-
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
99-
10097
self.buffer = []
10198
self.recv_cnt = 0
10299

@@ -116,11 +113,14 @@ def loop(self) -> None:
116113
i = 0
117114
if self.eval_interval > 0 and step % self.eval_interval == 0:
118115
eval_statistics = None
116+
eval_global_step = None
119117
for r in range(self.num_producers):
120118
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
121119
local_eval_result = ray_broadcast_tensor_dict(
122-
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
120+
None, src=0, device=self.device, group_name=f"sync_data_{r}"
123121
)
122+
assert "consumer_global_step" in local_eval_result
123+
eval_global_step = local_eval_result.pop("consumer_global_step").item()
124124
if eval_statistics is None:
125125
eval_statistics = local_eval_result
126126
else:
@@ -129,8 +129,8 @@ def loop(self) -> None:
129129
}
130130
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
131131
if dist.get_rank() == 0:
132-
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
133-
self.wandb_run.log(eval_statistics, step=self.global_step)
132+
if hasattr(self, "wandb_run"):
133+
self.wandb_run.log(eval_statistics, step=eval_global_step)
134134
print(f"Eval statistics: {eval_statistics}")
135135
for _ in range(self.num_recv_per_update):
136136
# receive data from producers

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def setup(self) -> None:
138138
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
139139
else:
140140
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
141-
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
142141

143142
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
144143
raise NotImplementedError
@@ -194,11 +193,14 @@ def loop(self) -> None:
194193
# delete the file if it exists
195194
safe_write_jsonl(result_file_name, eval_results)
196195
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
196+
eval_statistics["consumer_global_step"] = torch.tensor(
197+
[self.consumer_global_step], device=self.device
198+
)
197199
ray_broadcast_tensor_dict(
198200
eval_statistics,
199201
src=0,
200202
device=self.device,
201-
group_name=f"sync_eval_statistics_{self.producer_idx}",
203+
group_name=f"sync_data_{self.producer_idx}",
202204
)
203205
outputs = self.rollout(**batch)
204206

0 commit comments

Comments (0)