 This is a concrete implementation of BaseForgeActor for supervised fine-tuning.
 """
 
+import contextlib
 import logging
 
 import torch
     create_context_parallel_context,
     log_training_step,
     move_batch_to_device,
+    setup_eval_dataloaders,
     setup_sft_dataloader,
     setup_tokenizer,
 )
+from forge.data.utils import StopAfterOneEpoch
 from monarch.actor import endpoint
 from omegaconf import DictConfig
 
@@ -34,19 +37,16 @@ class TrainerActor(BaseForgeActor):
     Concrete trainer actor for supervised fine-tuning.
 
     Handles training loop, forward/backward passes, and checkpoint management.
+
+    Args:
+        config: Configuration dictionary containing training settings
     """
 
     train_spec: forge_train_spec.ForgeTrainSpec
     train_dataloader: any
     num_training_steps: int
 
     def __init__(self, config: DictConfig):
-        """
-        Initialize the trainer actor.
-
-        Args:
-            config: Configuration dictionary containing training settings
-        """
         super().__init__(config)
         self.num_training_steps = self.job_config.training.steps
 
@@ -61,6 +61,7 @@ async def setup(self):
             hf_assets_path=self.job_config.model.hf_assets_path
         )
 
+        # Setup training dataloader
         self.train_dataloader = setup_sft_dataloader(
             tokenizer=self.tokenizer,
             dataset_path="yahma/alpaca-cleaned",
@@ -70,6 +71,31 @@
             device=self.device,
         )
 
+        # Setup evaluation dataloaders if configured
+        eval_config = self.job_config.get("eval", {})
+        self.val_dataloaders = {}
+        self.eval_every_n_steps = eval_config.get("eval_every_n_steps")
+        max_eval_steps = eval_config.get("max_eval_steps")
+        self.max_eval_steps = (
+            max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+        )
+        self.validation_enabled = (
+            self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
+        )
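+        # For reference, a sketch of the `eval` config section consumed here (the
+        # keys read are eval_every_n_steps, max_eval_steps, and datasets). The
+        # fields of each datasets entry are an assumption, since entries are passed
+        # through to setup_eval_dataloaders unchanged:
+        #
+        #   eval:
+        #     eval_every_n_steps: 100
+        #     max_eval_steps: 50
+        #     datasets:
+        #       - path: yahma/alpaca-cleaned
+        #         split: validation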
+
+        if self.validation_enabled:
+            logger.info("Setting up eval datasets...")
+            eval_datasets_config = eval_config.get("datasets", [])
+            self.val_dataloaders = setup_eval_dataloaders(
+                tokenizer=self.tokenizer,
+                eval_datasets_config=eval_datasets_config,
+                target_tokens_per_pack=self.job_config.training.seq_len,
+                batch_size=self.job_config.training.local_batch_size,
+                device=self.device,
+            )
+            logger.info(f"Loaded {len(self.val_dataloaders)} eval datasets")
+
+        # Load checkpoint if one exists
         if self.checkpointer:
             logger.info("Loading checkpoint...")
             self.checkpointer.load(step=self.current_step)
@@ -163,14 +189,179 @@ async def run(self) -> None:
             self.train_step(batch)
             self.current_step += 1
 
+            # Run evaluation periodically if enabled
+            if (
+                self.validation_enabled
+                and self.current_step % self.eval_every_n_steps == 0
+            ):
+                await self.evaluate()
+
             if self.checkpointer:
                 self.checkpointer.save(
                     curr_step=self.current_step,
                     last_step=self.current_step == self.num_training_steps,
                 )
 
+        # Final evaluation
+        if self.validation_enabled:
+            logger.info("Running final evaluation at end of training...")
+            await self.evaluate()
+
         logger.info("Training complete!")
 
+    async def evaluate(self) -> None:
+        """
+        Run evaluation on multiple datasets, one at a time.
+
+        1. Set models to eval mode
+        2. For each eval dataset:
+           - Create fresh iterator (starts from epoch 0)
+           - Use StopAfterOneEpoch to iterate until epoch boundary
+           - Respect max_eval_steps cap if configured
+           - Record loss and step metrics
+        3. Restore models to train mode
+        """
+        logger.info("==Starting evaluation==")
+
+        # Set models to eval mode
+        for model_part in self.model_parts:
+            model_part.eval()
+
+        # Get DP mesh for epoch synchronization
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # For non-PP: disable gradients to save memory
+        maybe_no_grad = (
+            contextlib.nullcontext()
+            if self.parallel_dims.pp_enabled
+            else torch.no_grad()
+        )
+
+        # Evaluate each dataset sequentially
+        all_dataset_losses = []
+        all_dataset_steps = []
+
+        for dataset_name, val_dataloader in self.val_dataloaders.items():
+            logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+            total_loss = torch.tensor(0.0, device=self.device)
+            num_steps = 0
+
+            # NOTE: Assumes batch contains field "metrics" containing "num_epochs"
+            batch_iter = StopAfterOneEpoch(
+                iter=iter(val_dataloader),  # Fresh iterator from epoch 0
+                device=self.device,
+                dp_mesh=dp_mesh,
+            )
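+            # StopAfterOneEpoch yields batches until the "num_epochs" metric in the
+            # batch reports a new epoch; when dp_mesh is set, the epoch signal is
+            # presumably synchronized across data-parallel ranks so they all stop
+            # after the same number of eval steps.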
+
+            with maybe_no_grad:
+                for batch in batch_iter:
+                    # If max_eval_steps exceeds the number of batches in the
+                    # dataset, the epoch boundary stops iteration first
+                    if (
+                        self.max_eval_steps is not None
+                        and num_steps >= self.max_eval_steps
+                    ):
+                        logger.info(
+                            f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+                        )
+                        break
+
+                    # Move batch to device
+                    batch = move_batch_to_device(batch, self.device)
+
+                    # Forward pass only (no backward)
+                    labels = batch.pop("labels")
+                    loss = self.forward_backward_eval(batch, labels)
+                    total_loss += loss
+                    num_steps += 1
+
+                    logger.info(
+                        f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss.item():.4f}"
+                    )
+
+            # Log average loss for this dataset
+            avg_loss = (total_loss / max(num_steps, 1)).item()
+            all_dataset_losses.append(avg_loss)
+            all_dataset_steps.append(num_steps)
+            logger.info(
+                f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+            )
+
+        # Record macro and micro average losses across datasets
+        if len(all_dataset_losses) > 1:
+            # Macro: same weight for all datasets
+            macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+            logger.info(f"Macro avg loss (unweighted): {macro_avg_loss:.4f}")
+
+            # Micro: weighted mean by dataset size
+            total_steps = sum(all_dataset_steps)
+            micro_avg_loss = (
+                sum(
+                    loss * steps
+                    for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+                )
+                / total_steps
+            )
+            logger.info(f"Micro avg loss (weighted): {micro_avg_loss:.4f}")
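+            # Worked example with hypothetical numbers: dataset A averages 2.0 over
+            # 10 steps and dataset B averages 4.0 over 30 steps. Macro avg is
+            # (2.0 + 4.0) / 2 = 3.0 (each dataset weighted equally); micro avg is
+            # (2.0 * 10 + 4.0 * 30) / 40 = 3.5 (weighted by number of eval steps).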
+
+        # Restore train mode
+        for model_part in self.model_parts:
+            model_part.train()
+
+        logger.info("==Evaluation complete==")
+
+    def forward_backward_eval(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Perform forward pass only (for evaluation).
+
+        Args:
+            input_dict: Dictionary containing input tokens
+            labels: Ground truth labels
+
+        Returns:
+            Computed loss value
+        """
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+        inputs = input_dict["tokens"]
+
+        optional_context_parallel_ctx = create_context_parallel_context(
+            parallel_dims=parallel_dims,
+            inputs=inputs,
+            labels=labels,
+            model_parts=model_parts,
+            rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+        )
+
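+        # With pipeline parallelism, pp_schedule.step() runs the stages; only ranks
+        # holding the last stage receive per-microbatch losses, and other ranks
+        # return a -1.0 placeholder. Without PP there is a single model part, so the
+        # loss is computed directly from its output.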
+        if parallel_dims.pp_enabled:
+            with self.train_context(optional_context_parallel_ctx):
+                targets, losses = (
+                    (labels, []) if self.pp_has_last_stage else (None, None)
+                )
+                if self.pp_has_first_stage:
+                    self.pp_schedule.step(inputs, target=targets, losses=losses)
+                else:
+                    self.pp_schedule.step(target=targets, losses=losses)
+
+            loss = (
+                torch.sum(torch.stack(losses)).to(self.device)
+                if self.pp_has_last_stage
+                else torch.tensor(-1.0, device=self.device)
+            )
+        else:
+            with self.train_context(optional_context_parallel_ctx):
+                assert len(model_parts) == 1
+                with self.maybe_enable_amp:
+                    pred = model_parts[0](inputs)
+                    loss = self.loss_fn(pred, labels)
+                del pred
+
+        return loss
+
     @endpoint
     async def cleanup(self) -> None:
         """