Commit 0cc0c84

Author: Tong Li (committed)
Commit message: add save

1 parent 0f566cc, commit 0cc0c84

File tree: 3 files changed, +18 -4 lines changed

- applications/ColossalChat/coati/dataset/loader.py
- applications/ColossalChat/coati/distributed/consumer.py
- applications/ColossalChat/coati/distributed/grpo_consumer.py


applications/ColossalChat/coati/dataset/loader.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -357,7 +357,9 @@ def apply_chat_template_and_mask(
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
 
-    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning.\n"
+    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+
 
     system_element = {
         "role": "system",
```

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -1,6 +1,6 @@
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
-
+import os
 import ray
 import ray.util.collective as cc
 import torch
@@ -33,6 +33,8 @@ def __init__(
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
         microbatch_size: int = 1,
+        save_interval: int = 100,
+        save_dir: str = "./model"
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -44,6 +46,8 @@ def __init__(
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
         self.microbatch_size = microbatch_size
+        self.save_interval = save_interval
+        self.save_dir = save_dir
         assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // microbatch_size
 
@@ -116,6 +120,14 @@ def loop(self) -> None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     assert len(self.buffer) == 0
+                    if (step + 1) % self.save_interval == 0:
+                        if self.rank == 0:
+                            print(f"Start saving policy model at step {step + 1}.")
+                        save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+                        self.booster.save_model(self.policy_model, save_path, shard=True)
+                        if self.rank == 0:
+                            print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+
                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                         print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                         state_dict = self.state_dict()
```
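
With the new defaults (`save_interval=100`, `save_dir="./model"`), the consumer writes a sharded checkpoint every 100 update steps; note that `booster.save_model` runs on every rank and only the log lines are gated to rank 0. A quick sketch of the resulting save schedule and folder layout (the step count below is hypothetical, not from the commit):

```python
import os

save_interval = 100            # new constructor default
save_dir = "./model"           # new constructor default
num_update_per_episode = 250   # hypothetical value for illustration

# Steps at which `(step + 1) % save_interval == 0` fires inside loop():
saved_steps = [s + 1 for s in range(num_update_per_episode) if (s + 1) % save_interval == 0]
print(saved_steps)  # [100, 200]
print([os.path.join(save_dir, f"modeling-step-{s}") for s in saved_steps])
# ['./model/modeling-step-100', './model/modeling-step-200']
```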

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -32,7 +32,7 @@ def __init__(
         plugin_config,
         microbatch_size=1,
         num_generations=4,
-        use_wandb=False,
+        use_wandb=True,
     ):
         super().__init__(
             num_producers,
@@ -79,7 +79,7 @@ def __init__(
 
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if self.rank == 0:
+        if use_wandb and self.rank == 0:
             self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
 
     def setup(self):
```
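
This flips wandb logging to opt-out (`use_wandb=True` by default) while making the rank-0 guard actually respect the flag. A minimal sketch of the same gating pattern in isolation (assumed helper name; not the repo's code):

```python
def maybe_init_wandb(use_wandb: bool, rank: int):
    # Mirrors the guarded init above: only rank 0 with use_wandb=True
    # creates a run; every other process returns None and never imports wandb.
    if use_wandb and rank == 0:
        import wandb
        return wandb.init(project="GRPO-Test", sync_tensorboard=True)
    return None
```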
