Skip to content

Commit 58635c8

Browse files
committed
add max_steps for validation to avoid hang
1 parent 91b744c commit 58635c8

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@ training:
3232
steps: 1000
3333
compile: false
3434

35+
validation:
36+
local_batch_size: 1
37+
freq: -1 # Change to a positive number to enable validation
38+
steps: 200 # Max steps to run validation. Validation is disabled if this is not positive.
39+
3540
dataset:
3641
path: yahma/alpaca-cleaned
3742
split: train[:95%]
3843

39-
# Validation
40-
run_val_every_n_steps: null # Change to an integer to enable validation every N steps
4144
dataset_val:
4245
path: yahma/alpaca-cleaned
4346
split: train[95%:]

apps/sft/main.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(self, job_config: ForgeJobConfig):
6262
self.current_step = 0
6363
self.num_training_steps = job_config.training.steps
6464
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
65-
self._run_val_every_n_steps = job_config.get("run_val_every_n_steps", None)
6665
super().__init__(job_config)
6766
self.metric_logger = None # TODO: fix this
6867

@@ -74,8 +73,7 @@ def setup(self):
7473

7574
self.val_dataloader = self.setup_data(
7675
self.job_config.dataset_val,
77-
batch_size=self.job_config.training.local_batch_size,
78-
infinite=False,
76+
batch_size=self.job_config.validation.local_batch_size,
7977
)
8078

8179
# self.train_dataloader = self.setup_data(
@@ -236,19 +234,22 @@ def train(self) -> None:
236234
)
237235

238236
if (
239-
self._run_val_every_n_steps is not None
240-
and self.current_step % self._run_val_every_n_steps == 0
237+
self.job_config.validation.freq > 0
238+
and self.job_config.validation.steps > 0
239+
and self.current_step % self.job_config.validation.freq == 0
241240
):
242-
self.validate()
241+
self.validate(self.job_config.validation.steps)
243242

244-
def validate(self) -> None:
243+
def validate(self, max_steps: int) -> None:
245244
for m in self.model_parts:
246245
m.eval()
247246
total_val_loss = torch.tensor(0.0, device=self.device)
248247
total_val_tokens = torch.tensor(0.0, device=self.device)
249248
with torch.no_grad():
250249
val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
251250
for batch_idx, batch in enumerate(val_pbar):
251+
if batch_idx >= max_steps:
252+
break
252253
batch_to_device(batch, self.device)
253254
current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
254255
# Compute loss

0 commit comments

Comments (0)