Skip to content

Commit 40310bd

Browse files
[feat] Update consumer init to run 32B, update Qwen benchmark.
1 parent ad1ceb0 commit 40310bd

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def __init__(
6969
enable_profiling=enable_profiling,
7070
n_behind=n_behind,
7171
)
72-
path = model_config.pop("path")
73-
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
72+
self.path = model_config.pop("path")
73+
self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
7474
self.policy_model.train()
7575
self.policy_model.gradient_checkpointing_enable()
7676
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@@ -98,12 +98,7 @@ def __init__(
9898
loss_variation=grpo_config.get("loss_variation", "sample_level"),
9999
)
100100

101-
# Reference model is initialized from policy model.
102-
if self.policy_loss_fn.beta > 0:
103-
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
104-
self.reference_model.eval()
105-
106-
self.tokenizer = AutoTokenizer.from_pretrained(path)
101+
self.tokenizer = AutoTokenizer.from_pretrained(self.path)
107102
self.pad_token_id = self.tokenizer.pad_token_id
108103
self.num_generations = num_generations
109104
self.filter_range = grpo_config.get("filter_range", None)
@@ -148,7 +143,10 @@ def setup(self):
148143
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
149144
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
150145
)
146+
# Reference model is initialized from policy model.
151147
if self.policy_loss_fn.beta > 0:
148+
self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
149+
self.reference_model.eval()
152150
self.reference_model, *_ = self.booster.boost(self.reference_model)
153151
self.plugin.logger.set_level("ERROR")
154152

examples/language/qwen2/benchmark.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def main():
5353
# ==============================
5454
parser = argparse.ArgumentParser()
5555
parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
56-
parser.add_argument("-model", "--model_path", type=str, help="Model path")
56+
parser.add_argument("--model_path", type=str, help="Model path")
5757
parser.add_argument(
5858
"-p",
5959
"--plugin",
@@ -85,6 +85,7 @@ def main():
8585
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
8686
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
8787
parser.add_argument("--profile", action="store_true", help="Profile the code")
88+
parser.add_argument("--cpu_offload", action="store_true", help="Cpu offload")
8889
parser.add_argument(
8990
"--nsys",
9091
action="store_true",
@@ -142,6 +143,7 @@ def empty_init():
142143
pp_style=args.pp_style,
143144
num_model_chunks=args.n_chunks,
144145
zero_stage=args.zero,
146+
cpu_offload=args.cpu_offload,
145147
sp_size=args.sp,
146148
sequence_parallelism_mode=args.sp_mode,
147149
enable_sequence_parallelism=args.sp > 1,
@@ -204,7 +206,11 @@ def empty_init():
204206
)
205207

206208
model = Qwen2ForCausalLM.from_pretrained(
207-
MODEL_PATH, trust_remote_code=True, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
209+
args.model_path,
210+
trust_remote_code=True,
211+
use_flash_attention_2=False,
212+
use_cache=False,
213+
attn_implementation="eager",
208214
)
209215
if args.grad_checkpoint:
210216
model.gradient_checkpointing_enable()

examples/language/qwen2/hybrid_test_N1C8.sh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,14 @@
66

77
export OMP_NUM_THREADS=8
88

9-
#hybird: zero2+flash_atten+grad_ckpt+bs4
10-
colossalai run --nproc_per_node 8 benchmark.py -m "/home/grpo/models/Qwen2.5-7B/" -p "3d" -x -g --zero 1 -b 32 --mbs 1 --tp 2 --pp 2 -l 4096
9+
colossalai run --nproc_per_node 8 benchmark.py \
10+
--model_path "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/" \
11+
-p "3d" \
12+
-x -g \
13+
--zero 1 \
14+
--cpu_offload \
15+
-b 16 --mbs 1 \
16+
--tp 4 --pp 2 \
17+
-l 4096 \
18+
-s 3 \
19+
&>qwen2_7b.log &

0 commit comments

Comments (0)