Commit ac0897f

Reimplement top-p and top-k from #1578
Signed-off-by: Zhanda <zhandazhu@gmail.com>
1 parent 336803f · commit ac0897f

17 files changed: +1632 −291 lines
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
defaults: ../../grpo_math_1B.yaml
grpo:
  max_num_steps: 500
checkpointing:
  enabled: false
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron
  save_period: 100
policy:
  model_name: meta-llama/Llama-3.2-1B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.2-1B-Instruct
  optimizer: null
  megatron_cfg:
    enabled: true
    scheduler:
      lr_warmup_iters: 50
  dtensor_cfg:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
    temperature: 0.8
    top_p: 0.9
    top_k: 50
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.2-1b-instruct-1n8g-megatron
cluster:
  gpus_per_node: 8
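The config above exercises both samplers this commit reimplements (`top_p: 0.9`, `top_k: 50`). As a rough sketch of what those two settings mean, not the commit's actual implementation, top-k/top-p filtering over a single logit vector can look like this (the function name and NumPy formulation are illustrative):

```python
import numpy as np

def top_k_top_p_filter(logits, top_k=50, top_p=0.9):
    """Mask logits so sampling is restricted to the top_k most likely
    tokens, then to the smallest set whose cumulative probability
    exceeds top_p (nucleus sampling). Disallowed tokens become -inf."""
    logits = np.asarray(logits, dtype=np.float64).copy()

    # Top-k: everything below the k-th largest logit is disallowed.
    if 0 < top_k < logits.size:
        kth_largest = np.sort(logits)[-top_k]
        logits[logits < kth_largest] = -np.inf

    # Top-p: sort by probability and keep the shortest prefix whose
    # cumulative mass exceeds top_p.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]           # most likely first
    cumulative = np.cumsum(probs[order])
    exceeded = cumulative > top_p
    # Shift right so the token that first crosses top_p survives.
    remove = np.concatenate(([False], exceeded[:-1]))
    logits[order[remove]] = -np.inf
    return logits
```

Sampling then draws from a softmax over the filtered logits, so only the surviving tokens can ever be chosen.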
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
defaults: ../../grpo_math_1B.yaml
grpo:
  max_num_steps: 500
checkpointing:
  enabled: false
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron
  save_period: 100
policy:
  model_name: meta-llama/Llama-3.2-1B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.2-1B-Instruct
  optimizer: null
  megatron_cfg:
    enabled: true
    scheduler:
      lr_warmup_iters: 50
  dtensor_cfg:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
    temperature: 0.6
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.2-1b-instruct-1n8g-megatron
cluster:
  gpus_per_node: 8

nemo_rl/algorithms/loss_functions.py

Lines changed: 118 additions & 95 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)