|
1 | | -import os |
2 | 1 | from contextlib import nullcontext |
3 | 2 | from typing import Any, Dict, Optional |
4 | 3 |
|
|
7 | 6 | import torch |
8 | 7 | import torch.distributed as dist |
9 | 8 | from coati.distributed.profiling_utils import CustomProfiler |
| 9 | +from coati.utils import save_checkpoint |
10 | 10 | from tqdm import tqdm |
11 | 11 | from transformers import AutoModelForCausalLM |
12 | 12 |
|
13 | 13 | from colossalai.booster import Booster |
14 | 14 | from colossalai.booster.plugin import HybridParallelPlugin |
| 15 | +from colossalai.cluster import DistCoordinator |
15 | 16 | from colossalai.initialize import launch |
16 | 17 | from colossalai.nn.optimizer import HybridAdam |
17 | 18 | from colossalai.utils import get_current_device |
@@ -55,16 +56,19 @@ def __init__( |
55 | 56 | self.enable_profiling = enable_profiling |
56 | 57 | assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size" |
57 | 58 | self.num_microbatches = batch_size // minibatch_size |
| 59 | + self.checkpoint_path = model_config.pop("checkpoint_path", None) |
58 | 60 |
|
59 | 61 | self.model_config = model_config |
60 | 62 | self.plugin_config = plugin_config |
61 | 63 |
|
62 | 64 | self.device = get_current_device() |
63 | 65 | self.lr_scheduler = None |
64 | 66 | self.n_behind = n_behind |
 | 67 | +        self.total_prompt_trained = 0  # for setting start index when resuming training |
65 | 68 |
|
66 | 69 | def setup(self) -> None: |
67 | 70 | launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) |
| 71 | + self.coordinator = DistCoordinator() |
68 | 72 |
|
69 | 73 | plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) |
70 | 74 | if ( |
@@ -143,6 +147,26 @@ def calculate_effective_group_to_raw_group_mapping(self, step): |
143 | 147 | return effective_group_to_raw_group_mapping |
144 | 148 |
|
145 | 149 | def loop(self) -> None: |
| 150 | + self.profiler.enter("sync_model") |
| 151 | + torch.cuda.empty_cache() |
| 152 | + state_dict = self.state_dict() |
| 153 | + if self.pp_size > 1: |
| 154 | + if self.tp_rank == 0 and self.dp_rank == 0: |
| 155 | + ray_broadcast_tensor_dict( |
| 156 | + state_dict, |
| 157 | + src=self.num_producers, |
| 158 | + device=self.device, |
| 159 | + group_name=f"sync_model_{self.pp_rank}", |
| 160 | + ) |
| 161 | + else: |
| 162 | + if self.rank == 0: |
| 163 | + ray_broadcast_tensor_dict( |
| 164 | + state_dict, src=self.num_producers, device=self.device, group_name="sync_model" |
| 165 | + ) |
| 166 | + del state_dict |
| 167 | + torch.cuda.empty_cache() |
| 168 | + self.profiler.exit("sync_model") |
| 169 | + |
146 | 170 | print( |
147 | 171 | f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}" |
148 | 172 | ) |
@@ -208,6 +232,7 @@ def loop(self) -> None: |
208 | 232 | for k, v in raw_batch.items() |
209 | 233 | } |
210 | 234 | # [batch_size, num_generations] -> [batch_size] |
| 235 | + self.total_prompt_trained += raw_batch["reward"].size(0) |
211 | 236 | reward = raw_batch["reward"][:, :, 0] |
212 | 237 | format_acc = raw_batch["format_acc"][:, :, 0] |
213 | 238 | ans_acc = raw_batch["ans_acc"][:, :, 0] |
@@ -285,10 +310,19 @@ def loop(self) -> None: |
285 | 310 | if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode: |
286 | 311 | if self.rank == 0: |
287 | 312 | print(f"Start saving policy model at step {step + 1}.") |
288 | | - save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}") |
289 | | - self.booster.save_model(self.policy_model, save_path, shard=True) |
| 313 | + save_checkpoint( |
| 314 | + save_dir=self.save_dir, |
| 315 | + booster=self.booster, |
| 316 | + model=self.policy_model, |
| 317 | + optimizer=self.optimizer, |
| 318 | + lr_scheduler=self.lr_scheduler, |
| 319 | + epoch=episode, |
| 320 | + step=step, |
| 321 | + batch_size=int(self.total_prompt_trained / step), |
| 322 | + coordinator=self.coordinator, |
| 323 | + ) # for setting start index when resuming training |
290 | 324 | if self.rank == 0: |
291 | | - print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") |
| 325 | + print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}") |
292 | 326 |
|
293 | 327 | if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and ( |
294 | 328 | episode != 0 or step >= self.n_behind |
|
0 commit comments