Skip to content

Commit e589ec5

Browse files
committed
support resume training
1 parent 08a1244 commit e589ec5

File tree

5 files changed

+98
-10
lines changed

5 files changed

+98
-10
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from contextlib import nullcontext
32
from typing import Any, Dict, Optional
43

@@ -7,11 +6,13 @@
76
import torch
87
import torch.distributed as dist
98
from coati.distributed.profiling_utils import CustomProfiler
9+
from coati.utils import save_checkpoint
1010
from tqdm import tqdm
1111
from transformers import AutoModelForCausalLM
1212

1313
from colossalai.booster import Booster
1414
from colossalai.booster.plugin import HybridParallelPlugin
15+
from colossalai.cluster import DistCoordinator
1516
from colossalai.initialize import launch
1617
from colossalai.nn.optimizer import HybridAdam
1718
from colossalai.utils import get_current_device
@@ -55,16 +56,19 @@ def __init__(
5556
self.enable_profiling = enable_profiling
5657
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
5758
self.num_microbatches = batch_size // minibatch_size
59+
self.checkpoint_path = model_config.pop("checkpoint_path", None)
5860

5961
self.model_config = model_config
6062
self.plugin_config = plugin_config
6163

6264
self.device = get_current_device()
6365
self.lr_scheduler = None
6466
self.n_behind = n_behind
67+
self.total_prompt_trained = 0 # for setting start index when resuming training
6568

6669
def setup(self) -> None:
6770
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
71+
self.coordinator = DistCoordinator()
6872

6973
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
7074
if (
@@ -143,6 +147,26 @@ def calculate_effective_group_to_raw_group_mapping(self, step):
143147
return effective_group_to_raw_group_mapping
144148

145149
def loop(self) -> None:
150+
self.profiler.enter("sync_model")
151+
torch.cuda.empty_cache()
152+
state_dict = self.state_dict()
153+
if self.pp_size > 1:
154+
if self.tp_rank == 0 and self.dp_rank == 0:
155+
ray_broadcast_tensor_dict(
156+
state_dict,
157+
src=self.num_producers,
158+
device=self.device,
159+
group_name=f"sync_model_{self.pp_rank}",
160+
)
161+
else:
162+
if self.rank == 0:
163+
ray_broadcast_tensor_dict(
164+
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
165+
)
166+
del state_dict
167+
torch.cuda.empty_cache()
168+
self.profiler.exit("sync_model")
169+
146170
print(
147171
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
148172
)
@@ -208,6 +232,7 @@ def loop(self) -> None:
208232
for k, v in raw_batch.items()
209233
}
210234
# [batch_size, num_generations] -> [batch_size]
235+
self.total_prompt_trained += raw_batch["reward"].size(0)
211236
reward = raw_batch["reward"][:, :, 0]
212237
format_acc = raw_batch["format_acc"][:, :, 0]
213238
ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@ def loop(self) -> None:
285310
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
286311
if self.rank == 0:
287312
print(f"Start saving policy model at step {step + 1}.")
288-
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
289-
self.booster.save_model(self.policy_model, save_path, shard=True)
313+
save_checkpoint(
314+
save_dir=self.save_dir,
315+
booster=self.booster,
316+
model=self.policy_model,
317+
optimizer=self.optimizer,
318+
lr_scheduler=self.lr_scheduler,
319+
epoch=episode,
320+
step=step,
321+
batch_size=int(self.total_prompt_trained / step),
322+
coordinator=self.coordinator,
323+
) # for setting start index when resuming training
290324
if self.rank == 0:
291-
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
325+
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
292326

293327
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
294328
episode != 0 or step >= self.n_behind

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from coati.distributed.loss import PolicyLoss
99
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
1010
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
11+
from coati.utils import load_checkpoint
1112
from transformers import AutoModelForCausalLM, AutoTokenizer
1213

1314
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ def setup(self):
157158
)
158159
if self.policy_loss_fn.beta > 0:
159160
self.reference_model, *_ = self.booster.boost(self.reference_model)
161+
if self.checkpoint_path is not None:
162+
load_checkpoint(
163+
self.checkpoint_path,
164+
self.booster,
165+
self.policy_model,
166+
self.optimizer,
167+
self.lr_scheduler,
168+
)
160169
self.plugin.logger.set_level("ERROR")
161170

162171
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import torch
99
import tqdm
1010
import wandb
11+
from coati.dataset import StatefulDistributedSampler
1112
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
1213
from coati.distributed.profiling_utils import CustomProfiler
1314
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
1415
from coati.distributed.reward.verifiable_reward import VerifiableReward
16+
from coati.utils import load_checkpoint
1517
from ray.util.collective import allreduce
1618
from ray.util.collective.types import Backend, ReduceOp
1719
from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ def __init__(
6870
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
6971

7072
self.train_dataset_config = train_dataset_config
73+
self.checkpoint_path = model_config.pop("checkpoint_path", None)
7174
self.model_config = model_config
7275
self.generate_config = generate_config
7376
self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@ def __init__(
121124
self.train_dataloader = DataLoader(
122125
self.train_dataset,
123126
batch_size=microbatch_size,
124-
sampler=DistributedSampler(
127+
sampler=StatefulDistributedSampler(
125128
self.train_dataset,
126129
num_replicas=num_producers,
127130
rank=producer_idx,
@@ -133,6 +136,13 @@ def __init__(
133136
drop_last=True,
134137
collate_fn=collate_fn_grpo,
135138
)
139+
if self.checkpoint_path is not None:
140+
# resume training from checkpoint
141+
start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
142+
self.train_dataloader.sampler.set_start_index(sampler_start_idx)
143+
print(
144+
f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
145+
)
136146
if grpo_config["reward_fn_type"] == "think_answer_tags":
137147
self.evaluation_function = math_reward_fn
138148
elif grpo_config["reward_fn_type"] == "boxed":
@@ -203,6 +213,29 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
203213
raise NotImplementedError
204214

205215
def loop(self) -> None:
216+
217+
torch.cuda.empty_cache()
218+
self.profiler.enter("sync_model")
219+
if self.consumer_pp_size > 1:
220+
for pp_idx in range(self.consumer_pp_size):
221+
state_dict = ray_broadcast_tensor_dict(
222+
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
223+
)
224+
if "consumer_global_step" in state_dict:
225+
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
226+
self.load_state_dict(state_dict)
227+
else:
228+
state_dict = ray_broadcast_tensor_dict(
229+
None, self.num_producers, device=self.device, group_name="sync_model"
230+
)
231+
if "consumer_global_step" in state_dict:
232+
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
233+
self.load_state_dict(state_dict)
234+
self.profiler.exit("sync_model")
235+
print(f"[P{self.producer_idx}] Sync initial model done.")
236+
del state_dict
237+
torch.cuda.empty_cache()
238+
206239
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
207240
num_valid_microbatches = num_update_per_episode * self.num_microbatches
208241

applications/ColossalChat/coati/utils/ckpt_io.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ def load_checkpoint(
8181
"""
8282

8383
# Update booster params states.
84-
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
85-
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
86-
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
84+
if model is not None:
85+
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
86+
if optimizer is not None:
87+
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
88+
if lr_scheduler is not None:
89+
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
8790

8891
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
8992
return (

applications/ColossalChat/rl_example.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
if __name__ == "__main__":
1919
parser = argparse.ArgumentParser()
2020
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
21+
parser.add_argument(
22+
"-cp",
23+
"--checkpoint-path",
24+
type=str,
25+
default=None,
26+
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
27+
)
2128
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
2229
parser.add_argument(
2330
"-ed",
@@ -226,8 +233,10 @@
226233

227234
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
228235

229-
inference_model_config = dict(path=args.model)
230-
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
236+
inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
237+
train_model_config = dict(
238+
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
239+
)
231240
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
232241

233242
if args.backend == "transformers":

0 commit comments

Comments
 (0)