Skip to content

Commit 4794ca7

Browse files

fix: Handle disabled validation in SFT training (#1611)
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>

Authored by Sahger Lad · 1 parent: 48dbb37 · commit 4794ca7

File tree

3 files changed

+87
-32
lines changed

3 files changed

+87
-32
lines changed

examples/run_sft.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
107107
val_dataset = data.formatted_ds["validation"]
108108
sft_task_spec = data.task_spec
109109
print(
110-
f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset)} samples, respectively."
110+
f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset) if val_dataset else 0} samples, respectively."
111111
)
112112

113113
# add preprocessor if needed
@@ -133,19 +133,20 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
133133
max_seq_length=data_config["max_input_seq_length"],
134134
)
135135

136-
val_dataset = AllTaskProcessedDataset(
137-
val_dataset,
138-
tokenizer,
139-
sft_task_spec,
140-
partial(
141-
sft_preprocessor,
142-
add_bos=data_config.get("add_bos", True),
143-
add_eos=data_config.get("add_eos", True),
144-
add_generation_prompt=data_config["add_generation_prompt"],
145-
datum_preprocessor=datum_preprocessor,
146-
),
147-
max_seq_length=data_config["max_input_seq_length"],
148-
)
136+
if val_dataset is not None:
137+
val_dataset = AllTaskProcessedDataset(
138+
val_dataset,
139+
tokenizer,
140+
sft_task_spec,
141+
partial(
142+
sft_preprocessor,
143+
add_bos=data_config.get("add_bos", True),
144+
add_eos=data_config.get("add_eos", True),
145+
add_generation_prompt=data_config["add_generation_prompt"],
146+
datum_preprocessor=datum_preprocessor,
147+
),
148+
max_seq_length=data_config["max_input_seq_length"],
149+
)
149150

150151
return train_dataset, val_dataset, sft_task_spec
151152

nemo_rl/algorithms/sft.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ def setup(
9090
master_config: MasterConfig,
9191
tokenizer: AutoTokenizer,
9292
train_dataset: AllTaskProcessedDataset,
93-
val_dataset: AllTaskProcessedDataset,
93+
val_dataset: Optional[AllTaskProcessedDataset],
9494
) -> tuple[
9595
Policy,
9696
RayVirtualCluster,
9797
StatefulDataLoader,
98-
StatefulDataLoader,
98+
Optional[StatefulDataLoader],
9999
NLLLoss,
100100
Logger,
101101
CheckpointManager,
@@ -149,14 +149,17 @@ def setup(
149149
)
150150
train_dataloader.load_state_dict(dataloader_state_dict)
151151

152-
val_dataloader = StatefulDataLoader(
153-
val_dataset,
154-
batch_size=sft_config["val_global_batch_size"],
155-
shuffle=False,
156-
collate_fn=rl_collate_fn,
157-
drop_last=False,
158-
num_workers=data_config["num_workers"],
159-
)
152+
if val_dataset is not None:
153+
val_dataloader = StatefulDataLoader(
154+
val_dataset,
155+
batch_size=sft_config["val_global_batch_size"],
156+
shuffle=False,
157+
collate_fn=rl_collate_fn,
158+
drop_last=False,
159+
num_workers=data_config["num_workers"],
160+
)
161+
else:
162+
val_dataloader = None
160163

161164
# ==========================
162165
# Cluster
@@ -230,7 +233,7 @@ def setup(
230233
# =======================================================
231234
def validate(
232235
policy: PolicyInterface,
233-
val_dataloader: StatefulDataLoader,
236+
val_dataloader: Optional[StatefulDataLoader],
234237
tokenizer,
235238
loss_fn,
236239
step: int,
@@ -242,11 +245,11 @@ def validate(
242245
):
243246
"""Run validation on the validation dataset."""
244247
if val_dataloader is None:
245-
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
246-
"val_dataloader is None, so dpo.val_period must be 0"
248+
assert master_config["sft"]["val_period"] <= 0, (
249+
"val_dataloader is None, so sft.val_period must be <= 0"
247250
)
248251
print(" ⚠️ No validation dataloader provided, skipping validation")
249-
return
252+
return {}, {}
250253

251254
timer = Timer()
252255

@@ -496,7 +499,7 @@ def sft_train(
496499
metrics[k] = np.mean(v).item()
497500
else:
498501
metrics[k] = np.sum(v).item()
499-
total_valid_tokens += metrics["global_valid_toks"]
502+
total_valid_tokens += metrics.get("global_valid_toks", 0)
500503

501504
## Checkpointing
502505
sft_save_state["consumed_samples"] += master_config["policy"][
@@ -610,9 +613,12 @@ def sft_train(
610613
master_config["cluster"]["num_nodes"]
611614
* master_config["cluster"]["gpus_per_node"]
612615
)
613-
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
614-
metrics["global_valid_toks"] / total_time / total_num_gpus
615-
)
616+
if total_time > 0:
617+
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
618+
metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
619+
)
620+
else:
621+
timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0
616622
logger.log_metrics(metrics, total_steps + 1, prefix="train")
617623
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")
618624

tests/unit/algorithms/test_sft.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,51 @@ def test_exit_on_timeout(mock_components, capsys):
205205
assert "Epoch" not in line or "Epoch 1/10" in line, (
206206
f"Training continued to next epoch after timeout: {line}"
207207
)
208+
209+
210+
def test_training_with_disabled_validation(mock_components):
211+
"""Test that training works when validation is disabled (val_dataloader=None, val_period<=0)"""
212+
mock_components["master_config"]["sft"]["val_period"] = 0
213+
mock_components["master_config"]["sft"]["max_num_steps"] = 5
214+
mock_components["master_config"]["sft"]["max_num_epochs"] = 1
215+
216+
sft_save_state = _default_sft_save_state()
217+
218+
sft_train(
219+
mock_components["policy"],
220+
mock_components["train_dataloader"],
221+
None, # val_dataloader is None
222+
mock_components["tokenizer"],
223+
mock_components["loss_fn"],
224+
mock_components["master_config"],
225+
mock_components["logger"],
226+
mock_components["sft_task_spec"],
227+
mock_components["checkpointer"],
228+
sft_save_state,
229+
)
230+
231+
assert mock_components["policy"].train.call_count == 5
232+
233+
234+
def test_training_with_negative_val_period(mock_components):
235+
"""Test that training works when val_period is negative (validation disabled)"""
236+
mock_components["master_config"]["sft"]["val_period"] = -1
237+
mock_components["master_config"]["sft"]["max_num_steps"] = 3
238+
mock_components["master_config"]["sft"]["max_num_epochs"] = 1
239+
240+
sft_save_state = _default_sft_save_state()
241+
242+
sft_train(
243+
mock_components["policy"],
244+
mock_components["train_dataloader"],
245+
None, # val_dataloader is None
246+
mock_components["tokenizer"],
247+
mock_components["loss_fn"],
248+
mock_components["master_config"],
249+
mock_components["logger"],
250+
mock_components["sft_task_spec"],
251+
mock_components["checkpointer"],
252+
sft_save_state,
253+
)
254+
255+
assert mock_components["policy"].train.call_count == 3

0 commit comments

Comments (0)