@@ -79,9 +79,25 @@ def __init__(self, config: DictConfig):
7979 self ._rank = current_rank ().rank
8080 self ._size = math .prod (current_size ().values ())
8181
82- # Evaluation settings
83- self .eval_interval = job_config .training .get ("eval_interval" , float ("inf" ))
84- self .eval_steps = job_config .training .get ("eval_steps" , 0 )
82+ # Evaluation settings from validation config
83+ validation_config = job_config .get ("validation" , {})
84+ self .validation_enabled = validation_config .get ("enabled" , False )
85+
86+ if self .validation_enabled :
87+ self .eval_interval = validation_config .get ("eval_interval" )
88+ self .eval_steps = validation_config .get ("eval_steps" )
89+
90+ if self .eval_interval is None :
91+ raise ValueError (
92+ "validation.eval_interval is required when validation.enabled is true"
93+ )
94+ if self .eval_steps is None :
95+ raise ValueError (
96+ "validation.eval_steps is required when validation.enabled is true"
97+ )
98+ else :
99+ self .eval_interval = None
100+ self .eval_steps = None
85101
86102 self ._init_dist ()
87103 super ().__init__ (job_config )
@@ -113,23 +129,30 @@ def _init_dist(self):
113129
114130 @endpoint
115131 async def setup (self ):
116- # Setup training data (first 90% of train split)
132+ # Setup training data from config
133+ dataset_config = self .job_config .get ("dataset" )
134+
117135 self .train_dataloader = self .setup_data (
118- dataset_path = "yahma/alpaca-cleaned" , dataset_split = "train[:90%]"
136+ dataset_path = dataset_config .get ("path" ),
137+ dataset_split = dataset_config .get ("split" ),
119138 )
120139
121- # Setup validation data (last 10% of train split)
140+ # Setup validation data from config
141+ dataset_val_config = self .job_config .get ("dataset_val" , {})
122142 self .val_dataloader = self .setup_data (
123- dataset_path = "yahma/alpaca-cleaned" , dataset_split = "train[90%:]"
143+ dataset_path = dataset_val_config .get ("path" , dataset_config .get ("path" )),
144+ dataset_split = dataset_val_config .get ("split" , dataset_config .get ("split" )),
124145 )
125146
126147 # Load checkpoint if resuming
127148 self .checkpointer .load (step = self .current_step )
128149
129- def setup_data (
130- self , dataset_path : str = "yahma/alpaca-cleaned" , dataset_split : str = "train"
131- ):
150+ def setup_data (self , dataset_path : str , dataset_split : str ):
132151 """Setup data with configurable dataset path and split."""
152+ if not dataset_path or not dataset_split :
153+ raise ValueError (
154+ f"dataset.path and dataset.split are required in YAML config. Got path={ dataset_path } , split={ dataset_split } "
155+ )
133156 print (os .path .join (self .job_config .model .hf_assets_path , "tokenizer.json" ))
134157 tokenizer = HuggingFaceModelTokenizer (
135158 tokenizer_json_path = os .path .join (
@@ -281,39 +304,26 @@ def train_step(self, batch) -> None:
281304
282305 def _extract_epoch_from_batch (self , batch : dict ) -> int | None :
283306 """Extract epoch number from batch metrics."""
284- if "metrics" not in batch :
285- return None
286-
287- for metric in batch ["metrics" ]:
288- if hasattr (metric , "metric_name" ) and metric .metric_name == "num_epochs" :
289- return metric .value
307+ if "metrics" in batch :
308+ for metric in batch ["metrics" ]:
309+ if (
310+ hasattr (metric , "metric_name" )
311+ and metric .metric_name == "num_epochs"
312+ ):
313+ return metric .value
290314 return None
291315
292316 async def evaluate (self ) -> dict [str , float ]:
293- """Run evaluation on validation set for one complete epoch.
294-
295- Uses prefetch + non-blocking all_reduce pattern to detect epoch completion
296- across all ranks without blocking on every batch.
297-
298- Pattern:
299- - Iteration N: Start async all_reduce on next batch's epoch (non-blocking)
300- - Process current batch while all_reduce completes in background
301- - Iteration N+1: Check result from previous all_reduce (should be done)
302-
303- This overlaps communication with computation for better performance.
304- """
317+ """Run evaluation with async all_reduce for cross-rank epoch synchronization."""
305318 logger .info ("=" * 50 )
306- logger .info ("STARTING EVALUATION " )
319+ logger .info ("STARTING EVALUATION" )
307320 logger .info ("=" * 50 )
308321
309- # Set model to eval mode
310322 for model_part in self .model_parts :
311323 model_part .eval ()
312324
313325 val_dataloader = iter (self .val_dataloader )
314- total_loss = 0.0
315- num_batches = 0
316- starting_epoch = None
326+ total_loss , num_batches , starting_epoch = 0.0 , 0 , None
317327
318328 # Prefetch first batch
319329 try :
@@ -322,106 +332,79 @@ async def evaluate(self) -> dict[str, float]:
322332 logger .warning ("Validation dataloader is empty" )
323333 return {"val_loss" : 0.0 , "val_batches" : 0 }
324334
325- next_should_break = False
326- pending_work = None # Handle for async all_reduce
327- epoch_tensor = None # Tensor for all_reduce result
335+ should_break , pending_work , epoch_tensor = False , None , None
328336
329337 with torch .no_grad ():
330338 while True :
331- # Check result from PREVIOUS iteration's async all_reduce
339+ # Wait for previous async all_reduce to complete
332340 if pending_work is not None :
333- pending_work .wait () # Should be complete (or very fast) since we did compute
334- if epoch_tensor is not None :
335- next_should_break = epoch_tensor .item () > 0
341+ pending_work .wait ()
342+ should_break = (
343+ epoch_tensor .item () > 0 if epoch_tensor is not None else False
344+ )
336345 pending_work = None
337346
338- # Check if we should break (based on previous iteration's check)
339- if next_should_break :
347+ if should_break :
340348 logger .info (
341349 "Epoch completed across all ranks - stopping evaluation"
342350 )
343351 break
344352
345- # Check optional cap on eval steps
346353 if self .eval_steps > 0 and num_batches >= self .eval_steps :
347354 logger .info (f"Reached eval_steps cap of { self .eval_steps } " )
348355 break
349356
350- # Use the batch that was prefetched in previous iteration
351357 batch = next_batch
352358
353- # Extract epoch from current batch
359+ # Track starting epoch
354360 current_epoch = self ._extract_epoch_from_batch (batch )
355361 if current_epoch is not None and starting_epoch is None :
356362 starting_epoch = current_epoch
357- logger .info (f"Starting evaluation at epoch { starting_epoch } " )
358363
359- # Prefetch next batch and start async all_reduce
364+ # Prefetch next batch and start async epoch check
360365 try :
361366 next_batch = next (val_dataloader )
362-
363- # Extract epoch from next batch
364367 next_epoch = self ._extract_epoch_from_batch (next_batch )
365368
366- # Start NON-BLOCKING all_reduce to check if any rank completed epoch
367369 if next_epoch is not None and starting_epoch is not None :
368- # Check if next batch indicates epoch completion
369370 epoch_increment = next_epoch - starting_epoch
370-
371371 if torch .distributed .is_initialized ():
372- # Create tensor for all_reduce
373372 epoch_tensor = torch .tensor (
374373 [epoch_increment ], dtype = torch .long , device = self .device
375374 )
376- # Start async all_reduce (returns immediately, doesn't block)
377375 pending_work = torch .distributed .all_reduce (
378376 epoch_tensor ,
379377 op = torch .distributed .ReduceOp .MAX ,
380- async_op = True , # NON-BLOCKING - returns immediately
378+ async_op = True ,
381379 )
382380 else :
383- # Single rank case - just check locally
384- next_should_break = epoch_increment > 0
385-
381+ should_break = epoch_increment > 0
386382 except StopIteration :
387- # No more batches - this is the last one
388- next_should_break = True
383+ should_break = True
389384
390- # Process current batch (while all_reduce completes in background)
391- # Move tensors to device
385+ # Process current batch (overlaps with async all_reduce)
392386 for k , v in batch .items ():
393387 if isinstance (v , torch .Tensor ):
394388 batch [k ] = v .to (self .device )
395389
396390 labels = batch .pop ("labels" )
397391 loss = self .forward_only (batch , labels )
398- # GPU compute happens here while network does all_reduce
399-
400392 total_loss += loss .item ()
401393 num_batches += 1
402394
403- eval_steps_info = f"/{ self .eval_steps } " if self .eval_steps > 0 else ""
404- logger .info (
405- f" Eval batch { num_batches } { eval_steps_info } | Loss: { loss .item ():.4f} "
406- )
395+ if num_batches % 10 == 0 :
396+ logger .info (f" Eval batch { num_batches } | Loss: { loss .item ():.4f} " )
407397
408- # Set model back to train mode
409398 for model_part in self .model_parts :
410399 model_part .train ()
411400
412401 avg_loss = total_loss / max (num_batches , 1 )
413-
414- metrics = {
415- "val_loss" : avg_loss ,
416- "val_batches" : num_batches ,
417- }
418-
419- logger .info ("-" * 50 )
420- logger .info (f"EVALUATION COMPLETE" )
421- logger .info (f"Validation Loss: { avg_loss :.4f} " )
422- logger .info (f"Batches Evaluated: { num_batches } " )
402+ logger .info (
403+ f"EVALUATION COMPLETE | Val Loss: { avg_loss :.4f} | Batches: { num_batches } "
404+ )
423405 logger .info ("=" * 50 )
424- return metrics
406+
407+ return {"val_loss" : avg_loss , "val_batches" : num_batches }
425408
426409 @endpoint
427410 async def train (self ) -> None :
@@ -439,8 +422,8 @@ async def train(self) -> None:
439422 self .train_step (batch )
440423 self .current_step += 1
441424
442- # Run evaluation periodically
443- if self .current_step % self .eval_interval == 0 :
425+ # Run evaluation periodically if enabled
426+ if self .validation_enabled and self . current_step % self .eval_interval == 0 :
444427 eval_metrics = await self .evaluate ()
445428 logger .info (f"Step { self .current_step } | Eval metrics: { eval_metrics } " )
446429
0 commit comments