Skip to content

Commit c784dd9

Browse files
feat: add data shuffle and random seed option (#334)
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent c249efc commit c784dd9

36 files changed

+212 −12 lines changed

examples/configs/dpo.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,7 @@ policy:
150150
data:
151151
dataset_name: "HelpSteer3"
152152
max_input_seq_length: ${policy.max_total_sequence_length}
153+
shuffle: true
153154
logger:
154155
log_dir: "logs" # Base directory for all logs
155156
wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running

examples/configs/grpo-deepscaler-1.5b-8K.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ grpo:
1010
val_at_start: false
1111
max_val_samples: 480
1212
val_batch_size: 32
13+
seed: 42
1314

1415
loss_fn:
1516
reference_policy_kl_penalty: 0.0
@@ -118,6 +119,7 @@ data:
118119
prompt_file: "examples/prompts/cot.txt"
119120
system_prompt_file: null
120121
dataset_name: "DeepScaler"
122+
shuffle: true
121123

122124
env:
123125
math:

examples/configs/grpo_math_1B.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ grpo:
1010
val_at_start: false
1111
max_val_samples: 256
1212
val_batch_size: 256
13+
seed: 42
1314

1415
loss_fn:
1516
reference_policy_kl_penalty: 0.01
@@ -127,6 +128,7 @@ data:
127128
prompt_file: "examples/prompts/cot.txt"
128129
system_prompt_file: null
129130
dataset_name: "OpenMathInstruct-2"
131+
shuffle: true
130132

131133
env:
132134
math:

examples/configs/grpo_math_1B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,7 @@ data:
146146
prompt_file: "examples/prompts/cot.txt"
147147
system_prompt_file: null
148148
dataset_name: "OpenMathInstruct-2"
149+
shuffle: true
149150

150151
env:
151152
math:

examples/configs/grpo_sliding_puzzle.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ policy:
4444

4545
data:
4646
add_system_prompt: false
47+
shuffle: false # disable dataloader shuffle, shuffle is handled within the dataset
4748

4849
env:
4950
sliding_puzzle_game:

examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ policy:
7373
data:
7474
dataset_name: "HelpSteer3"
7575
max_input_seq_length: ${policy.max_total_sequence_length}
76+
shuffle: true
7677

7778
logger:
7879
log_dir: "logs"

examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ policy:
7373
data:
7474
dataset_name: "HelpSteer3"
7575
max_input_seq_length: ${policy.max_total_sequence_length}
76+
shuffle: true
7677

7778
logger:
7879
log_dir: "logs"

examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ policy:
106106
data:
107107
dataset_name: "HelpSteer3"
108108
max_input_seq_length: ${policy.max_total_sequence_length}
109+
shuffle: true
109110

110111
logger:
111112
log_dir: "logs"

examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ policy:
106106
data:
107107
dataset_name: "HelpSteer3"
108108
max_input_seq_length: ${policy.max_total_sequence_length}
109+
shuffle: true
109110

110111
logger:
111112
log_dir: "logs"

examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,8 @@ policy:
7474
data:
7575
dataset_name: "HelpSteer3"
7676
max_input_seq_length: ${policy.max_total_sequence_length}
77+
shuffle: true
78+
7779
logger:
7880
log_dir: "logs"
7981
wandb_enabled: true

0 commit comments

Comments
 (0)