Commit a84f3b4

test: add non-colocated nightly test (#960)
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent d168de3 commit a84f3b4

27 files changed: +183 -24 lines changed
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
  max_rollout_turns: 1
  max_num_steps: 500
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  ratio_clip_c: null
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  token_level_loss: true
checkpointing:
  enabled: true
  checkpoint_dir: results/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  metric_name: val_reward
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null
policy:
  model_name: meta-llama/Llama-3.1-8B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.1-8B-Instruct
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32
  logprob_batch_size: 2
  max_total_sequence_length: 4096
  precision: bfloat16
  dtensor_cfg:
    enabled: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null
  dynamic_batching:
    enabled: True
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64
  sequence_packing:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64
  make_sequence_length_divisible_by: 1
  max_grad_norm: 1
  optimizer:
    name: torch.optim.AdamW
    kwargs:
      lr: 3e-07
      weight_decay: 0.01
      betas:
        - 0.9
        - 0.999
      eps: 1e-08
      foreach: false
      fused: false
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1
        total_iters: 13
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1
        total_iters: 10000000000
    - milestones:
        - 13
  generation:
    backend: vllm
    max_new_tokens: 4096
    temperature: 1
    top_p: 1
    top_k: null
    stop_token_ids:
      - 128009
    stop_strings: null
    vllm_cfg:
      async_engine: true
      precision: ${policy.precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: 4096
      enforce_eager: False
    colocated:
      enabled: false
      resources:
        gpus_per_node: null
        num_nodes: 1
data:
  max_input_seq_length: 4096
  prompt_file: examples/prompts/cot.txt
  system_prompt_file: null
  dataset_name: OpenMathInstruct-2
  shuffle: true
env:
  math:
    num_workers: 8
logger:
  log_dir: logs/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  num_val_samples_to_print: 0
  wandb_enabled: true
  tensorboard_enabled: true
  mlflow_enabled: false
  monitor_gpus: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10
cluster:
  gpus_per_node: 8
  num_nodes: 2
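Note on the dynamic_batching and sequence_packing token budgets above: the ${mul:...} entries use OmegaConf-style interpolation with a custom "mul" resolver, so each microbatch token budget tracks max_total_sequence_length times the corresponding batch size (4096 * 1 for training, 4096 * 2 for logprob computation here). Below is a minimal sketch of how such a resolver could be registered and resolved; the registration shown is an illustrative assumption, not necessarily how the repository wires it up.

# Hedged sketch (assumption): resolving the ${mul:...} entries with OmegaConf.
# The "mul" resolver name comes from the config above; this registration is
# illustrative only.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)

cfg = OmegaConf.create({
    "policy": {
        "max_total_sequence_length": 4096,
        "train_micro_batch_size": 1,
        "logprob_batch_size": 2,
        "dynamic_batching": {
            "train_mb_tokens": "${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}",
            "logprob_mb_tokens": "${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}",
        },
    }
})

print(cfg.policy.dynamic_batching.train_mb_tokens)    # 4096 * 1 = 4096
print(cfg.policy.dynamic_batching.logprob_mb_tokens)  # 4096 * 2 = 8192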

tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.sh

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/preference_loss"]["1"] < 0.69316' \
 'data["train/preference_loss"]["20"] < 0.6' \
 'mean(data["timing/train/total_step_time"], -10, -1) < 7.8'
-fi
+fi
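For context, these nightly scripts pass Python-style expressions such as 'mean(data["timing/train/total_step_time"], -10, -1) < 7.8' to tests/check_metrics.py, where metrics are keyed by step number. The sketch below shows how such checks could be evaluated against a metrics JSON; the mean helper's semantics and the file layout are assumptions inferred from the expressions above, not the actual implementation in the repository.

# Hedged sketch (assumption): evaluating check_metrics-style expressions.
import json


def mean(metric, start=None, end=None):
    # Metrics are assumed to be dicts keyed by step number ("1", "2", ...);
    # order by step and average the values in the [start:end] slice,
    # mirroring the "-10, -1" usage in the scripts.
    values = [v for _, v in sorted(metric.items(), key=lambda kv: int(kv[0]))]
    sliced = values[start:end]
    return sum(sliced) / len(sliced)


with open("metrics.json") as f:  # hypothetical metrics file
    data = json.load(f)

checks = [
    'data["train/preference_loss"]["1"] < 0.69316',
    'mean(data["timing/train/total_step_time"], -10, -1) < 7.8',
]
for expr in checks:
    assert eval(expr, {"data": data, "mean": mean}), expr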

tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp4.sh

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/preference_loss"]["1"] < 0.69316' \
 'data["train/preference_loss"]["150"] < 0.4' \
 'mean(data["timing/train/total_step_time"], -11, -1) < 24'
-fi
+fi

tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.v2.sh

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/preference_loss"]["1"] < 0.69316' \
 'data["train/preference_loss"]["150"] < 0.4' \
 'mean(data["timing/train/total_step_time"], -11, -1) < 11.5'
-fi
+fi

tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.sh

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/preference_loss"]["1"] < 0.69316' \
 'data["train/preference_loss"]["20"] < 0.6' \
 'mean(data["timing/train/total_step_time"], -10) < 6.7'
-fi
+fi

tests/test_suites/llm/dpo-llama3.1-8b-tulu3-1n8g-fsdp2tp1.sh

Lines changed: 0 additions & 1 deletion
@@ -40,4 +40,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/preference_loss"]["1"] < 0.6932' \
 'data["train/preference_loss"]["150"] < 0.68'
 fi
-

tests/test_suites/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.sh

Lines changed: 1 addition & 1 deletion
@@ -38,4 +38,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
 'data["train/loss"]["1"] < 0.69316' \
 'data["train/loss"]["150"] < 0.55' \
 'mean(data["timing/train/total_step_time"], -11, -1) < 1.3'
-fi
+fi

tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh

Lines changed: 0 additions & 1 deletion
@@ -66,4 +66,3 @@ cat ${RUN_LOG}.aime-16k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"sco
 # 240 step checkpoint 0.3
 uv run tests/check_metrics.py ${RUN_LOG}-16k-metric.json \
 'data["score"] >= 0.2396'
-

tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,3 @@ cat ${RUN_LOG}.aime-24k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"sco
6565

6666
uv run tests/check_metrics.py ${RUN_LOG}-24k-metric.json \
6767
'data["score"] >= 0.2396'
68-

tests/test_suites/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
3737
'mean(data["train/token_mult_prob_error"]) < 1.1' \
3838
'data["train/token_mult_prob_error"]["20"] < 1.1'
3939
fi
40-
