Skip to content

Commit 355aa98

Browse files
authored
perf: Add num_workers in DPO, GRPO and SFT for loading data (#1314)
Signed-off-by: Kate Cheng <[email protected]>
1 parent 4db1704 commit 355aa98

File tree

10 files changed

+14
-2
lines changed

examples/configs/dpo.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ policy:
158158
data:
159159
max_input_seq_length: ${policy.max_total_sequence_length}
160160
shuffle: true
161+
num_workers: 1
161162

162163
dataset_name: HelpSteer3
163164
# You can use custom preference datasets for training and validation. For example:

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,7 @@ data:
219219
prompt_file: "examples/prompts/cot.txt"
220220
system_prompt_file: null
221221
shuffle: true
222+
num_workers: 1
222223

223224
dataset_name: "OpenMathInstruct-2"
224225
# You can use custom response datasets for training and validation. For example:

examples/configs/rm.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,7 @@ policy:
129129
data:
130130
max_input_seq_length: ${policy.max_total_sequence_length}
131131
shuffle: true
132+
num_workers: 1
132133

133134
dataset_name: HelpSteer3
134135
# You can use custom preference datasets for training and validation. For example:

examples/configs/sft_openmathinstruct2_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ data:
132132
add_eos: true
133133
add_generation_prompt: true
134134
output_key: 'generated_solution'
135+
num_workers: 1
135136

136137
logger:
137138
log_dir: "logs" # Base directory for all logs

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,7 @@ data:
203203
dataset_name: "clevr-cogent"
204204
split: "trainA"
205205
shuffle: true
206+
num_workers: 1
206207

207208
env:
208209
clevr-cogent:

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,7 @@ data:
156156
dataset_name: clevr-cogent
157157
split: trainA
158158
shuffle: true
159+
num_workers: 1
159160
env:
160161
clevr-cogent:
161162
num_workers: 8

nemo_rl/algorithms/dpo.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,7 @@ def setup(
176176
add_loss_mask=True,
177177
),
178178
drop_last=True,
179+
num_workers=data_config["num_workers"],
179180
)
180181

181182
if last_checkpoint_path is not None:
@@ -198,6 +199,7 @@ def setup(
198199
add_loss_mask=True,
199200
),
200201
drop_last=False,
202+
num_workers=data_config["num_workers"],
201203
)
202204
for k, v in val_dataset.items()
203205
}

nemo_rl/algorithms/grpo.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,7 @@ def setup(
207207
shuffle=data_config["shuffle"],
208208
collate_fn=rl_collate_fn,
209209
drop_last=True,
210+
num_workers=data_config["num_workers"],
210211
)
211212
if last_checkpoint_path is not None:
212213
dataloader_state_dict = torch.load(
@@ -228,6 +229,7 @@ def setup(
228229
batch_size=grpo_config["val_batch_size"],
229230
shuffle=False,
230231
collate_fn=rl_collate_fn,
232+
num_workers=data_config["num_workers"],
231233
)
232234
print(
233235
f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",

nemo_rl/algorithms/rm.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,7 @@ def setup(
151151
add_loss_mask=False,
152152
),
153153
drop_last=True,
154+
num_workers=data_config["num_workers"],
154155
)
155156

156157
if last_checkpoint_path is not None:
@@ -173,6 +174,7 @@ def setup(
173174
add_loss_mask=False,
174175
),
175176
drop_last=False,
177+
num_workers=data_config["num_workers"],
176178
)
177179
for k, v in val_dataset.items()
178180
}

tests/unit/algorithms/test_grpo.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,7 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
240240
"val_period": 0,
241241
"val_at_start": False,
242242
},
243-
"data": {"shuffle": False},
243+
"data": {"shuffle": False, "num_workers": 1},
244244
"logger": {}, # Config extraction requires this key
245245
"checkpointing": {}, # Config extraction requires this key
246246
"cluster": {
@@ -296,7 +296,7 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
296296
"val_period": 0,
297297
"val_at_start": False,
298298
},
299-
"data": {"shuffle": False},
299+
"data": {"shuffle": False, "num_workers": 1},
300300
"logger": {}, # Config extraction requires this key
301301
"checkpointing": {}, # Config extraction requires this key
302302
"cluster": {

0 commit comments

Comments (0)