@@ -35,21 +35,30 @@ def extract_epoch_from_batch(batch: dict) -> int | None:
         Epoch number from metrics, or None if not found
     """
     if "metrics" in batch:
+        # Look for num_epochs in metric keys
+        for metric in batch["metrics"]:
+            # Metrics have a 'key' attribute with paths like:
+            # 'dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs'
+            if hasattr(metric, "key") and "num_epochs" in metric.key:
+                return int(metric.value)
+
+        # Fallback: check for old-style metric_name attribute
         for metric in batch["metrics"]:
             if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
-                return metric.value
+                return int(metric.value)
+
     return None
 
 
 def start_epoch_sync(
-    epoch_increment: int,
+    epoch_changed: bool,
     device: torch.device,
     dp_process_group: Any = None,
 ) -> tuple[torch.Tensor | None, Any]:
     """Start async all_reduce for epoch synchronization across ranks.
 
     Args:
-        epoch_increment: Difference between current and starting epoch
+        epoch_changed: Whether the epoch changed on this rank
         device: Device for tensor
         dp_process_group: Data parallel process group (None = default group)
 
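For orientation, a minimal sketch of the new key-based lookup against a stand-in metric object. The Metric namedtuple here is hypothetical; it only models the attributes the function actually inspects:

from collections import namedtuple

# Hypothetical stand-in for the dataloader's metric objects
Metric = namedtuple("Metric", ["key", "value"])

batch = {
    "metrics": [
        Metric(key="dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs", value=2.0)
    ]
}
assert extract_epoch_from_batch(batch) == 2                # matched via 'num_epochs' in metric.key
assert extract_epoch_from_batch({"metrics": []}) is None   # no matching metric -> None

Note that int(metric.value) also normalizes float-valued metrics, which is why both return paths now cast.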
@@ -59,7 +68,8 @@ def start_epoch_sync(
     if not torch.distributed.is_initialized():
         return None, None
 
-    epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device)
+    # Convert bool to tensor: 1 if epoch changed, 0 otherwise
+    epoch_tensor = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device)
     pending_work = torch.distributed.all_reduce(
         epoch_tensor,
         op=torch.distributed.ReduceOp.MAX,
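Since each rank contributes 0 or 1 and the reduction is ReduceOp.MAX, the reduced value is 1 exactly when at least one rank crossed an epoch boundary. A sketch of both call paths; the wait-then-read pattern assumes the all_reduce is issued with async_op=True, which this hunk does not show:

import torch

# Single process: no process group is initialized, so there is nothing to reduce
tensor, work = start_epoch_sync(epoch_changed=True, device=torch.device("cpu"))
assert tensor is None and work is None  # caller falls back to the local epoch_changed

# Distributed (conceptually): ranks contributing e.g. [0, 1, 0, 0] reduce to 1
# under MAX, so every rank observes the same stop signal.
# epoch_tensor, pending_work = start_epoch_sync(epoch_changed, device, dp_group)
# pending_work.wait()                        # resolve the async handle
# should_break = bool(epoch_tensor.item())   # 1 if any rank changed epoch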
@@ -117,7 +127,7 @@ def eval_loop(
         Tuple of (avg_loss, num_batches)
     """
     total_loss = torch.tensor(0.0, device=device)
-    num_batches, starting_epoch = 0, None
+    num_batches = 0
 
     # Prefetch first batch
     next_batch = next(dataloader_iter)
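The starting_epoch state goes away because the new detection compares consecutive batches instead of measuring distance from the first epoch seen. A toy contrast of the two checks (epoch values are illustrative):

epochs = [0, 0, 1, 1]  # epoch of each consecutive batch
# Old approach: compare against the first epoch seen (needs starting_epoch state)
old_boundary = [e - epochs[0] > 0 for e in epochs[1:]]      # [False, True, True]
# New approach: compare each batch with its successor (stateless)
new_boundary = [b > a for a, b in zip(epochs, epochs[1:])]  # [False, True, False]

Only the consecutive check fires precisely at the boundary and carries no cross-iteration state, which is why num_batches is the only counter left.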
@@ -142,26 +152,41 @@ def eval_loop(
 
             batch = next_batch
 
-            # Track starting epoch
+            # Get current batch epoch
             current_epoch = extract_epoch_fn(batch)
-            if starting_epoch is None:
-                starting_epoch = current_epoch
 
-            # Prefetch next batch and start async epoch check
+            # Prefetch next batch and check for epoch change
            try:
                 next_batch = next(dataloader_iter)
                 next_epoch = extract_epoch_fn(next_batch)
 
-                # Only check epochs if both are available
-                if next_epoch is not None and starting_epoch is not None:
-                    epoch_increment = next_epoch - starting_epoch
+                # Simple check: did epoch change between consecutive batches?
+                if next_epoch is not None and current_epoch is not None:
+                    epoch_changed = next_epoch > current_epoch
+
+                    if epoch_changed:
+                        logger.info(
+                            f"[{dataset_name}] Epoch change detected: "
+                            f"{current_epoch} → {next_epoch}"
+                        )
+
                     if torch.distributed.is_initialized():
+                        # All-reduce: if ANY rank's epoch changed, all ranks should stop
                         epoch_tensor, pending_work = start_epoch_sync(
-                            epoch_increment, device, dp_process_group
+                            epoch_changed, device, dp_process_group
                         )
                     else:
-                        should_break = epoch_increment > 0
+                        # Single process: stop immediately if epoch changed
+                        should_break = epoch_changed
+                else:
+                    # No epoch tracking available - rely on eval_steps
+                    if num_batches == 0:
+                        logger.info(
+                            f"[{dataset_name}] No epoch tracking available "
+                            f"(current={current_epoch}, next={next_epoch})"
+                        )
             except StopIteration:
+                logger.info(f"[{dataset_name}] StopIteration - dataloader exhausted")
                 should_break = True
 
             # Process current batch (overlaps with async all_reduce)
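The hunk ends at the overlap comment, so the consume side of the async check is not shown. Under the assumption that the loop resolves the handle after the forward pass, the tail of the loop body would look roughly like this; model_forward_fn is purely illustrative, not the PR's actual eval step:

# Process current batch (overlaps with the in-flight all_reduce)
loss = model_forward_fn(batch)          # illustrative per-batch eval step
total_loss += loss.detach()
num_batches += 1

# Resolve the epoch check started during prefetch
if pending_work is not None:
    pending_work.wait()
    should_break = epoch_tensor.item() > 0  # 1 if any rank crossed an epoch boundary
    pending_work = None

if should_break:
    break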