@@ -62,7 +62,6 @@ def __init__(self, job_config: ForgeJobConfig):
6262 self .current_step = 0
6363 self .num_training_steps = job_config .training .steps
6464 self .gradient_accumulation_steps = 1 # Example value, adjust as needed
65- self ._run_val_every_n_steps = job_config .get ("run_val_every_n_steps" , None )
6665 super ().__init__ (job_config )
6766 self .metric_logger = None # TODO: fix this
6867
@@ -74,8 +73,7 @@ def setup(self):
7473
7574 self .val_dataloader = self .setup_data (
7675 self .job_config .dataset_val ,
77- batch_size = self .job_config .training .local_batch_size ,
78- infinite = False ,
76+ batch_size = self .job_config .validation .local_batch_size ,
7977 )
8078
8179 # self.train_dataloader = self.setup_data(
@@ -236,19 +234,22 @@ def train(self) -> None:
236234 )
237235
238236 if (
239- self ._run_val_every_n_steps is not None
240- and self .current_step % self ._run_val_every_n_steps == 0
237+ self .job_config .validation .freq > 0
238+ and self .job_config .validation .steps > 0
239+ and self .current_step % self .job_config .validation .freq == 0
241240 ):
242- self .validate ()
241+ self .validate (self . job_config . validation . steps )
243242
244- def validate (self ) -> None :
243+ def validate (self , max_steps : int ) -> None :
245244 for m in self .model_parts :
246245 m .eval ()
247246 total_val_loss = torch .tensor (0.0 , device = self .device )
248247 total_val_tokens = torch .tensor (0.0 , device = self .device )
249248 with torch .no_grad ():
250249 val_pbar = tqdm (self .val_dataloader , desc = "Validation" , leave = False )
251250 for batch_idx , batch in enumerate (val_pbar ):
251+ if batch_idx >= max_steps :
252+ break
252253 batch_to_device (batch , self .device )
253254 current_num_tokens = (batch ["labels" ] != CROSS_ENTROPY_IGNORE_IDX ).sum ()
254255 # Compute loss