Skip to content

Commit cde2acd

Browse files
authored
perf: Add a field in SFT data config to modify num_workers for loading data (#1143)
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
1 parent 42aa41b commit cde2acd

File tree

3 files changed

+8
-0
lines changed

3 files changed

+8
-0
lines changed

examples/configs/sft.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ data:
136136
add_eos: true
137137
add_generation_prompt: false
138138
shuffle: true
139+
num_workers: 1
139140

140141
dataset_name: "squad"
141142
# You can use custom response datasets for training and validation. For example:

nemo_rl/algorithms/sft.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def setup(
138138
shuffle=data_config["shuffle"],
139139
collate_fn=rl_collate_fn,
140140
drop_last=True,
141+
num_workers=data_config["num_workers"],
141142
)
142143

143144
if last_checkpoint_path is not None:
@@ -152,6 +153,7 @@ def setup(
152153
shuffle=False,
153154
collate_fn=rl_collate_fn,
154155
drop_last=False,
156+
num_workers=data_config["num_workers"],
155157
)
156158

157159
# ==========================

nemo_rl/data/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ class DataConfig(TypedDict):
3333
download_dir: NotRequired[str]
3434
train_data_path: NotRequired[str]
3535
val_data_paths: NotRequired[dict[str, str]]
36+
# Number of data loader workers.
37+
# Set to 8 or 10 for large batches to improve loading speed.
38+
# This saturates CPU threads without consuming too much memory.
39+
# However, setting it too high might cause memory issues for long sequence lengths.
40+
num_workers: NotRequired[int]
3641

3742

3843
class MathDataConfig(DataConfig):

0 commit comments

Comments
 (0)