Skip to content

Commit 9467c10

Browse files
YeAnbang and Tong Li authored
[hot-fix] Fix memory leakage bug, support TP+PP (#6258)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation
* fix memory leakage support tp+pp
* move empty cache
* move empty cache

---------

Co-authored-by: Tong Li <[email protected]>
1 parent ed43a4b commit 9467c10

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,8 @@ def setup(self) -> None:
7272
self.plugin = HybridParallelPlugin(**plugin_config)
7373
self.booster = Booster(plugin=self.plugin)
7474
self.dp_rank = dist.get_rank(self.plugin.dp_group)
75+
self.tp_rank = dist.get_rank(self.plugin.tp_group)
76+
7577
self.dp_size = dist.get_world_size(self.plugin.dp_group)
7678

7779
self.buffer = []
@@ -127,11 +129,14 @@ def loop(self) -> None:
127129

128130
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
129131
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
132+
torch.cuda.empty_cache()
130133
state_dict = self.state_dict()
131134
if self.rank == 0:
132135
ray_broadcast_tensor_dict(
133136
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
134137
)
138+
del state_dict
139+
torch.cuda.empty_cache()
135140

136141

137142
@ray.remote

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ def setup(self):
109109
super().setup()
110110
if self.use_wandb and (
111111
(not self.plugin.pp_size > 1 and self.rank == 0)
112-
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
112+
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
113113
):
114114
# Initialize wandb.
115115
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -282,10 +282,9 @@ def _criterion(outputs, inputs):
282282

283283
if self.booster.plugin.stage_manager.is_last_stage():
284284
if len(kl) > 0:
285-
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
285+
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
286286
mean_kl.append(kl)
287-
loss = all_reduce_mean(loss, self.plugin)
288-
mean_loss.append(loss.data)
287+
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
289288
else:
290289

291290
policy_model_logits = self.policy_model(
@@ -336,7 +335,7 @@ def _criterion(outputs, inputs):
336335
mean_kl.append(kl.data)
337336
mean_loss.append(loss.data)
338337
if not self.plugin.pp_size > 1 or (
339-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
338+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
340339
):
341340
reward = all_reduce_mean(reward.mean(), self.plugin)
342341
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
@@ -355,11 +354,11 @@ def _criterion(outputs, inputs):
355354
self.optimizer.step()
356355
self.optimizer.zero_grad()
357356
if not self.plugin.pp_size > 1 or (
358-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
357+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
359358
):
360359
loss_scalar = self.accum_loss.item()
361360
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
362-
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
361+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
363362
):
364363
print(
365364
"Loss:",

applications/ColossalChat/rl_example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@
121121
# plugin_config={}, # for zero
122122
plugin_config={
123123
"pp_size": 2,
124-
"tp_size": 1,
124+
"tp_size": 2,
125125
"microbatch_size": args.train_microbatch_size // 2,
126126
"zero_stage": 0,
127127
"max_norm": 1.0,

0 commit comments

Comments (0)