
Commit bf9007f

Add a single-host qwen3 8b config, auto deduce n_layers (meta-pytorch#215)
* add qwen3 8b config, auto deduce n_layers
* update docstring
1 parent 435729a commit bf9007f

File tree

2 files changed: +134 -1 lines changed


apps/grpo/qwen3_8b.yaml

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml

# Global configuration
group_size: 8
batch_size: 16
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0

# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 8B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${batch_size}
    seq_len: 2048
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: -1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: ${off_by_n}
  # This should match the dp_size of TorchTitan
  # Here it's set explicitly to 2, because we've set
  # 2 GPUs for the trainer and we're using full FSDP.
  dp_size: 2

# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 8B
    hf_assets_path: hf://${model}
  training:
    dtype: bfloat16
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    initial_load_path: hf://${model}
    initial_load_in_hf: true

# All resource allocations
services:
  dataset:
    procs: 1
    num_replicas: 1
    with_gpus: false
  policy:
    procs: ${policy.engine_config.tensor_parallel_size}
    num_replicas: 1
    with_gpus: true
  trainer:
    procs: 2
    num_replicas: 1
    with_gpus: true
  replay_buffer:
    procs: 1
    num_replicas: 1
    with_gpus: false
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
  compute_advantages:
    procs: 1
    num_replicas: 1
    with_gpus: false
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
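
Not part of the commit, but as a reading aid: a minimal sketch of how the ${...} references in this config resolve, assuming an OmegaConf-style loader (the exact loader forge uses may differ). Scalar references point back at top-level keys, and dotted references such as ${policy.engine_config.tensor_parallel_size} resolve through the nested structure.

# Minimal sketch, assuming OmegaConf-style interpolation.
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_8b.yaml")

# Scalar interpolation: dataset.model points back at the top-level model.
assert cfg.dataset.model == "Qwen/Qwen3-8B"

# Dotted interpolation: the policy service gets one proc per tensor-parallel rank.
assert cfg.services.policy.procs == cfg.policy.engine_config.tensor_parallel_size == 2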

src/forge/actors/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -264,7 +264,9 @@ async def push_weights(self, policy_version: int) -> None:
         )
         hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
         # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
-        vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
+        vllm_ready_hf_sd = _qwen3_hf_to_vllm(
+            sd=hf_state_dict, num_layers=self.engine.model_args.n_layers
+        )

         key = f"{self.state_dict_key}{DELIM}{policy_version}"
         start_time = time.time()
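
The change reads the layer count from the trainer's own model args instead of hardcoding 28, so the HF-to-vLLM conversion also covers the deeper Qwen3-8B checkpoint added above. The sketch below is hypothetical and not the repo's _qwen3_hf_to_vllm; it only illustrates why a per-layer conversion loop must use the true depth, assuming vLLM's usual fused gate_up_proj layout.

import torch

# Hypothetical per-layer rename/fuse pass; key names and fusion details are
# assumptions, not the actual _qwen3_hf_to_vllm implementation.
def hf_to_vllm_sketch(sd: dict[str, torch.Tensor], num_layers: int) -> dict[str, torch.Tensor]:
    out = dict(sd)
    for i in range(num_layers):
        mlp = f"model.layers.{i}.mlp"
        gate = out.pop(f"{mlp}.gate_proj.weight")
        up = out.pop(f"{mlp}.up_proj.weight")
        # vLLM-style models typically keep a single fused gate_up_proj per MLP block.
        out[f"{mlp}.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
    return out

# Layers with index >= num_layers are never visited, so a hardcoded
# num_layers=28 would leave the extra layers of a deeper model unconverted;
# self.engine.model_args.n_layers tracks whatever flavor the config selects.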
