
Commit 26d859f

YeAnbang, Tong Li, and pre-commit-ci[bot] authored
[feat] Support DAPO (#6263)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation
* fix memory leakage support tp+pp
* move empty cache
* move empty cache
* add DAPO support
* remove format reward
* fix filtering, still buggy
* small fix
* add DAPO support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tested multi-node training; fix bind_batch bug
* fix conversation; support sleep mode
* support reusing excessive samples
* add dynamic batching control flag
* add dynamic batching control flag
* refactored
* fix logging

---------

Co-authored-by: Tong Li <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b823c6e · commit 26d859f
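The DAPO-specific items in this commit ("fix filtering", "support reusing excessive samples") revolve around dynamic sampling: prompt groups whose rollouts all receive the same reward give zero group-relative advantage, so they are filtered out and the surplus samples are kept for later minibatches instead of being discarded. The sketch below is a minimal, framework-agnostic illustration of that filtering idea only, assuming a flat reward tensor of `num_prompts * num_generations` entries; the function name is made up and is not the code added by this commit.

from typing import List

import torch


def filter_zero_advantage_groups(rewards: torch.Tensor, num_generations: int) -> List[int]:
    """Return indices of prompt groups worth training on (illustrative only).

    A group whose rollouts all share the same reward yields zero advantage
    under a group-normalized baseline, so DAPO-style dynamic sampling drops it.
    """
    grouped = rewards.view(-1, num_generations)      # (num_prompts, num_generations)
    keep = grouped.std(dim=1) > 0                    # any reward variation within the group?
    return keep.nonzero(as_tuple=True)[0].tolist()


# Illustrative use: 3 prompts, 4 rollouts each; the middle group is all-correct.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,
                        1.0, 1.0, 1.0, 1.0,
                        0.0, 0.0, 1.0, 0.0])
print(filter_zero_advantage_groups(rewards, num_generations=4))  # [0, 2]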

File tree: 10 files changed (+552, −359 lines)


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 24 additions & 14 deletions
@@ -16,7 +16,7 @@
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import bind_batch, pad_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -33,7 +33,7 @@ def __init__(
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
-        microbatch_size: int = 1,
+        minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -46,11 +46,11 @@ def __init__(
         self.num_update_per_episode = num_update_per_episode
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
-        self.microbatch_size = microbatch_size
+        self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
-        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-        self.num_microbatches = batch_size // microbatch_size
+        assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
+        self.num_microbatches = batch_size // minibatch_size
 
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -67,7 +67,7 @@ def setup(self) -> None:
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+            plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
@@ -105,18 +105,26 @@ def loop(self) -> None:
                                     )
                                 )
                             )
-                        while len(self.buffer) >= self.dp_size * self.microbatch_size:
+                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
                             batches = self.buffer[
-                                self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                             ]
-                            self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
+                            batch = pad_batch(
+                                batches
+                            )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss = self.step(i, **batch)
+                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
+                            self.buffer = (
+                                self.buffer[
+                                    (self.dp_rank + 1) * self.minibatch_size
+                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
+                                ]
+                                + self.buffer[self.dp_size * self.minibatch_size :]
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
-                        assert len(self.buffer) == 0
                     if self.lr_scheduler is not None:
                         self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -154,7 +162,9 @@ def __init__(
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
+        save_interval: int = 100,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,
@@ -168,7 +178,7 @@ def __init__(
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -181,7 +191,7 @@ def setup(self):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels
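The new `pad_batch` call exists because, once minibatches can mix samples produced in different inference micro-batches (the in-diff comment's `imbs` vs `tMbs`), the sample dicts being stacked may carry different sequence lengths. The real helper is imported from `coati/distributed/utils.py`; the version below is only a sketch of the padding step, assuming each sample is a dict of 1-D `input_ids` / `attention_mask` tensors.

from typing import Dict, List

import torch
import torch.nn.functional as F


def pad_batch_sketch(
    batches: List[Dict[str, torch.Tensor]], pad_token_id: int = 0
) -> List[Dict[str, torch.Tensor]]:
    """Right-pad every sample to the longest sequence length in `batches`
    so the tensors can later be stacked along a new batch dimension."""
    max_len = max(b["input_ids"].size(-1) for b in batches)
    padded = []
    for b in batches:
        pad = max_len - b["input_ids"].size(-1)
        padded.append(
            {
                "input_ids": F.pad(b["input_ids"], (0, pad), value=pad_token_id),
                "attention_mask": F.pad(b["attention_mask"], (0, pad), value=0),
            }
        )
    return padded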

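The reworked buffer update replaces the old `assert len(self.buffer) == 0`: the `num_excessive_prompts` reported back by `step` are kept at the tail of this rank's minibatch slice and prepended onto the untouched remainder, so surplus samples are reused in the next round. A small numeric walk-through of that slicing, using plain Python lists and made-up sizes:

# Suppose dp_size = 2, minibatch_size = 4, and this consumer is dp_rank = 0.
dp_rank, dp_size, minibatch_size = 0, 2, 4
buffer = list(range(10))  # 10 queued samples: [0, 1, ..., 9]

# This rank trains on its slice buffer[0:4]; say step() reports 1 excessive
# prompt that should be carried over instead of dropped.
num_excessive_prompts = 1

buffer = (
    buffer[(dp_rank + 1) * minibatch_size - num_excessive_prompts : (dp_rank + 1) * minibatch_size]
    + buffer[dp_size * minibatch_size :]
)
print(buffer)  # [3, 8, 9] -> sample 3 is reused, 8 and 9 were never consumed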