
Commit 2ca1e3c

fix pp+tp, fix dataloader (#6280)
1 parent 28795f5 commit 2ca1e3c

File tree: 5 files changed, +17 −8 lines

applications/ColossalChat/coati/distributed/consumer.py
applications/ColossalChat/coati/distributed/grpo_consumer.py
applications/ColossalChat/coati/distributed/launch.py
applications/ColossalChat/coati/distributed/producer.py
applications/ColossalChat/rl_example.py

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 5 additions & 1 deletion
@@ -66,7 +66,11 @@ def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
-        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+        if (
+            self.plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in self.plugin_config
+            and "microbatch_size" not in self.plugin_config
+        ):
             plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
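
A minimal, standalone sketch of the defaulting-then-merge order this hunk relies on (function name hypothetical; the real code lives in the consumer's setup): the pipeline default is only filled in when the user configured neither num_microbatches nor microbatch_size, and the user's plugin_config is applied last, so explicit values always win.

    def build_plugin_config(user_config: dict, minibatch_size: int) -> dict:
        config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
        if (
            user_config.get("pp_size", 1) > 1
            and "num_microbatches" not in user_config
            and "microbatch_size" not in user_config
        ):
            config["microbatch_size"] = minibatch_size  # default for pp schedules
        config.update(user_config)  # user-supplied keys take precedence
        return config

    # An explicit microbatch_size is now left untouched:
    assert build_plugin_config({"pp_size": 2, "microbatch_size": 4}, 16)["microbatch_size"] == 4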

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 9 additions & 3 deletions
@@ -49,6 +49,12 @@ def __init__(
                 UserWarning,
             )
             minibatch_size = batch_size
+        if (
+            plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in plugin_config
+            and "microbatch_size" not in plugin_config
+        ):
+            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
         super().__init__(
             num_producers,
             num_episodes,
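
The max(1, ...) clamp matters because integer division can return 0, and a zero microbatch size is invalid for a pipeline schedule. A quick check of the arithmetic (assuming train_microbatch_size is set in grpo_config):

    # train_microbatch_size -> derived microbatch_size (halved, floored at 1)
    for tmb in (1, 2, 3, 8):
        print(tmb, "->", max(1, tmb // 2))  # 1 -> 1, 2 -> 1, 3 -> 1, 8 -> 4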
@@ -373,7 +379,7 @@ def _criterion(outputs, inputs):
                     loss_mask=inputs["loss_mask"],
                     total_effective_tokens_in_batch=total_effective_tokens_count,
                 )
-                return loss, num_excessive_samples // self.num_generations
+                return loss

             policy_model_outputs = self.booster.execute_pipeline(
                 iter([data_policy_forward]),
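
The criterion now returns only the loss, presumably because booster.execute_pipeline backpropagates whatever the criterion returns, which must be a single tensor; the excess-sample count is still reported by the outer training step instead. A minimal sketch of the expected shape of such a criterion (loss formula illustrative, not the repo's actual objective):

    import torch

    def criterion(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # must return one loss tensor; returning a tuple would break the
        # pipeline schedule's backward pass
        return torch.nn.functional.mse_loss(outputs, targets)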
@@ -468,10 +474,10 @@ def _criterion(outputs, inputs):
                 sample_utilization = self.effective_sample_count / self.total_sample_count
                 self.effective_sample_count = 0
                 self.total_sample_count = 0
+                loss_scalar = self.accum_loss.item()
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
-                    loss_scalar = self.accum_loss.item()
                     if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                         self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                     ):
@@ -507,7 +513,7 @@ def _criterion(outputs, inputs):
                         self.accum_advantages.zero_()
                         self.accum_response_length.zero_()
                         self.accum_count = 0
-                        return loss_scalar, num_excessive_samples // self.num_generations
+                    return loss_scalar, num_excessive_samples // self.num_generations
                 else:
                     return None, num_excessive_samples // self.num_generations
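
Together, these two hunks read the accumulated loss once, before the pipeline/tensor-parallel gate, and dedent the return so every rank that passes the gate reports the scalar; previously both the .item() call and the return sat inside the innermost logging branch. A condensed, self-contained sketch of the new control flow (all names illustrative):

    def step_result(accum_loss: float, pp_size: int, is_last_stage: bool, tp_rank: int, num_excessive: int):
        loss_scalar = accum_loss  # read once, before any rank-dependent branching
        if pp_size <= 1 or (is_last_stage and tp_rank == 0):
            # metric logging and accumulator resets happen here
            return loss_scalar, num_excessive
        return None, num_excessive

    print(step_result(0.37, pp_size=2, is_last_stage=True, tp_rank=0, num_excessive=4))   # (0.37, 4)
    print(step_result(0.37, pp_size=2, is_last_stage=False, tp_rank=0, num_excessive=4))  # (None, 4)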

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ def launch_distributed(
     inference_batch_size: int,
     inference_microbatch_size: int,
     train_batch_size: int,
-    train_microbatch_size: int,
     train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion
@@ -68,6 +68,7 @@ def __init__(
                 seed=42,
             ),
             num_workers=4,
+            drop_last=True,
         )
         self.device = get_current_device()
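
drop_last=True discards a trailing partial batch, so every iteration of the producer's dataloader yields identically sized batches. A quick standalone demonstration with PyTorch's DataLoader:

    from torch.utils.data import DataLoader

    loader = DataLoader(list(range(10)), batch_size=4, drop_last=True)
    print([len(batch) for batch in loader])  # [4, 4] -- the partial batch of 2 is dropped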

@@ -116,7 +117,6 @@ def loop(self) -> None:
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):

applications/ColossalChat/rl_example.py

Lines changed: 2 additions & 2 deletions
@@ -198,7 +198,6 @@
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        train_microbatch_size=args.train_microbatch_size,
         dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
@@ -216,7 +215,8 @@
         # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
-        #     "microbatch_size": args.train_microbatch_size // 2,
+        #     "pp_size": 2,
+        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
         # }, # for pp
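
Uncommenting that block would enable a combined tp+pp run. An illustrative example of the resulting configuration (values are examples only, assuming train_microbatch_size == 8; not a recommendation):

    plugin_config = {
        "tp_size": 2,
        "pp_size": 2,
        "microbatch_size": max(1, 8 // 2),  # train_microbatch_size == 8 -> 4
        "zero_stage": 0,
        "max_norm": 1.0,
    }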
