Skip to content

Commit ad1c0b6

Browse files
committed
Support Qwen-Omni GRPO training recipe
Signed-off-by: root <zhangyuekai@foxmail.com>
1 parent a426896 commit ad1c0b6

File tree

15 files changed

+669
-18
lines changed

15 files changed

+669
-18
lines changed
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
---
# GRPO training recipe for Qwen2.5-Omni-3B (audio QA) on the Megatron backend.
# NOTE(review): this file was reconstructed from a garbled diff scrape; nesting
# follows the standard NeMo-RL GRPO config schema — verify against the upstream
# example configs before use.

# Core GRPO rollout / optimization loop settings.
grpo:
  num_prompts_per_step: 8
  num_generations_per_prompt: 4
  max_rollout_turns: 1
  max_num_epochs: 1
  max_num_steps: 500
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  val_at_end: false
  overlong_filtering: false
  max_val_samples: 32
  val_batch_size: 32
  seed: 42
  use_dynamic_sampling: false
  batch_multiplier: 1
  reward_shaping:
    enabled: false
    overlong_buffer_length: 512
    overlong_buffer_penalty: 1
    max_response_length: ${policy.max_total_sequence_length}
  # Advantage Estimator Configuration
  # Options: "grpo" (default) or "reinforce_plus_plus"
  adv_estimator:
    name: "grpo"  # Use "reinforce_plus_plus" for Reinforce++ estimator
    normalize_rewards: ${grpo.normalize_rewards}
    use_leave_one_out_baseline: ${grpo.use_leave_one_out_baseline}
    minus_baseline: true  # Reinforce++-baseline specific: subtract per-prompt mean baseline
  reward_scaling:
    enabled: false
    source_min: 0.0
    source_max: 1.0
    target_min: 0.0
    target_max: 1.0
  async_grpo:
    enabled: false
    max_trajectory_age_steps: 1
  # NOTE(review): placed under `grpo` per upstream schema — confirm nesting.
  seq_logprob_error_threshold: null

# Policy-gradient loss configuration.
loss_fn:
  reference_policy_kl_penalty: 0.01
  # Can be set to k1, k2, k3
  # For more details, see http://joschu.net/blog/kl-approx.html
  reference_policy_kl_type: "k3"
  kl_input_clamp_value: 20.0
  kl_output_clamp_value: 10.0
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  ratio_clip_c: null
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  truncated_importance_sampling_ratio: null
  token_level_loss: true
  force_on_policy_ratio: false

checkpointing:
  enabled: true
  checkpoint_dir: results/audio_grpo_3B_megatron
  # Quoted: value contains a colon.
  metric_name: "val:accuracy"
  higher_is_better: true
  keep_top_k: 3
  save_period: 100
  checkpoint_must_save_by: null

# Model, parallelism, generation, and optimizer settings.
policy:
  model_name: /workspace_yuekai/HF/Qwen2.5-Omni-3B
  tokenizer:
    name: ${policy.model_name}
  train_global_batch_size: 32
  train_micro_batch_size: 1
  generation_batch_size: 32
  logprob_batch_size: 4
  max_total_sequence_length: 2048
  precision: bfloat16
  offload_optimizer_for_logprob: false
  # DTensor backend is disabled; Megatron backend (below) is used instead.
  dtensor_cfg:
    _v2: true
    enabled: false
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null
  dynamic_batching:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
  max_grad_norm: 1.0
  sequence_packing:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: modified_first_fit_decreasing
    sequence_length_round: 64
  # Linear warmup for 50 iters, then constant LR.
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 50
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 50
  generation:
    backend: vllm
    max_new_tokens: 1024
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: ${policy.precision}
      kv_cache_dtype: "auto"
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      expert_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      enforce_eager: false
      enable_expert_parallel: false
      # Audio/multimodal models require tokenizer to be initialized before generation
      skip_tokenizer_init: false
      limit_mm_per_prompt:
        audio: 1
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null
  megatron_cfg:
    enabled: true
    empty_unused_memory_level: 1
    activation_checkpointing: false
    converter_type: Qwen2_5OmniForConditionalGeneration
    tensor_model_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    pipeline_model_parallel_size: 1
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    sequence_parallel: false
    freeze_moe_router: true
    moe_router_dtype: fp64
    moe_router_load_balancing_type: none
    moe_router_bias_update_rate: 0.0
    moe_permute_fusion: false
    apply_rope_fusion: false
    bias_activation_fusion: true
    defer_fp32_logits: false
    moe_per_layer_logging: false
    moe_enable_deepep: false
    moe_token_dispatcher_type: "allgather"
    moe_shared_expert_overlap: false
    # NOTE(review): `peft` nested under megatron_cfg based on key ordering in
    # the original — confirm against the consuming schema.
    peft:
      enabled: false
      target_modules: []
      exclude_modules: []
      dim: 8
      alpha: 32
      dropout: 0.0
      dropout_position: "post"
      lora_A_init_method: "xavier"
      lora_B_init_method: "zero"
      a2a_experimental: false
      lora_dtype: null
    optimizer:
      optimizer: adam
      lr: 2.0e-07
      min_lr: 2.0e-07
      weight_decay: 0.01
      bf16: true
      fp16: false
      params_dtype: float32
      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1.0e-08
      sgd_momentum: 0.9
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true
      clip_grad: ${policy.max_grad_norm}
      optimizer_cpu_offload: false
      optimizer_offload_fraction: 0.0
    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: constant
      lr_decay_style: constant
      lr_decay_iters: 1000
      lr_warmup_iters: 50
      lr_warmup_init: 2.0e-08
    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: false
      overlap_param_gather: true
      use_custom_fsdp: false
      data_parallel_sharding_strategy: optim_grads_params
    fp8_cfg:
      enabled: false
      fp8: "e4m3"
      fp8_recipe: "blockwise"
      fp8_param: false

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  shuffle: true
  num_workers: 1

  # use multiple dataloader for train
  use_multiple_dataloader: false

  # dataset
  train:
    dataset_name: avqa
    split: train
  validation:
    dataset_name: avqa
    split: validation
  # default settings for all datasets
  default:
    prompt_file: examples/prompts/avqa_cot.txt
    system_prompt_file: null
    processor: "vlm_hf_data_processor"
    # NOTE(review): `env_name` grouped with the per-dataset defaults based on
    # original key order — confirm nesting.
    env_name: "avqa"

# Environment / reward configuration for the avqa task.
env:
  avqa:
    num_workers: 8
    # Weighted sum of format-adherence and exact-match rewards.
    reward_functions:
      - name: format
        weight: 0.2
      - name: exact_alnum
        weight: 0.8

logger:
  log_dir: logs
  num_val_samples_to_print: 0
  wandb_enabled: true
  tensorboard_enabled: true
  swanlab_enabled: false
  mlflow_enabled: false
  monitor_gpus: false
  wandb:
    project: grpo-dev
    name: audio-grpo-3b-megatron
  swanlab:
    project: grpo-dev
    name: audio-grpo-3b-megatron
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10

cluster:
  gpus_per_node: 8
  num_nodes: 1

0 commit comments

Comments
 (0)