"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""
1211
1312import asyncio
4039from torchtitan .experiments .forge .engine import ForgeEngine
4140from torchtitan .experiments .forge .job_config import ForgeJobConfig
4241
43- # from tqdm import tqdm
44-
4542# stubs for now
4643Checkpointer = Any
4744Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
6461 checkpointer : Checkpointer
6562 tokenizer : Tokenizer
6663 train_dataloader : Dataloader
67- # val_dataloader: Dataloader
64+ val_dataloader : Dataloader
6865 metric_logger : MetricLogger
6966 profiler : Profiler
7067 device : torch .device
@@ -81,6 +78,11 @@ def __init__(self, config: DictConfig):
8178 self .gradient_accumulation_steps = 1 # Example value, adjust as needed
8279 self ._rank = current_rank ().rank
8380 self ._size = math .prod (current_size ().values ())
81+
82+ # Evaluation settings
83+ self .eval_interval = job_config .training .get ("eval_interval" , float ("inf" ))
84+ self .eval_steps = job_config .training .get ("eval_steps" , 0 )
85+
8486 self ._init_dist ()
8587 super ().__init__ (job_config )
8688
@@ -111,25 +113,23 @@ def _init_dist(self):
111113
112114 @endpoint
113115 async def setup (self ):
114- self .train_dataloader = self .setup_data ()
115- # self.train_dataloader = self.setup_data(
116- # self.train_config.train_dataset_config,
117- # self.train_config.train_dataloader_config,
118- # self.train_config.packing_config,
119- # )
120- # self.val_dataloader = self.setup_data(
121- # self.train_config.val_dataset_config,
122- # self.train_config.val_dataloader_config,
123- # self.train_config.packing_config,
124- # )
125-
126- # TODO: confirm that this is working properly
127- # Should also use load, not dcp_load
116+ # Setup training data (first 90% of train split)
117+ self .train_dataloader = self .setup_data (
118+ dataset_path = "yahma/alpaca-cleaned" , dataset_split = "train[:90%]"
119+ )
120+
121+ # Setup validation data (last 10% of train split)
122+ self .val_dataloader = self .setup_data (
123+ dataset_path = "yahma/alpaca-cleaned" , dataset_split = "train[90%:]"
124+ )
125+
126+ # Load checkpoint if resuming
128127 self .checkpointer .load (step = self .current_step )
129- # self.profiler = self.setup_profiler(self.train_config.profiler_config)
130- # self.logger = self.setup_logger(self.train_config.logger_config)
131128
132- def setup_data (self ):
129+ def setup_data (
130+ self , dataset_path : str = "yahma/alpaca-cleaned" , dataset_split : str = "train"
131+ ):
132+ """Setup data with configurable dataset path and split."""
133133 print (os .path .join (self .job_config .model .hf_assets_path , "tokenizer.json" ))
134134 tokenizer = HuggingFaceModelTokenizer (
135135 tokenizer_json_path = os .path .join (
@@ -146,8 +146,8 @@ def setup_data(self):
146146 dataset = sft_iterable_dataset (
147147 model_transform = tokenizer ,
148148 message_transform = AlpacaToMessages (),
149- path = "yahma/alpaca-cleaned" ,
150- split = "train" ,
149+ path = dataset_path ,
150+ split = dataset_split ,
151151 )
152152 packer = TextPacker (padding_idx = 0 )
153153 dataset = PackedDataset (
@@ -163,10 +163,6 @@ def setup_data(self):
163163 ),
164164 )
165165
166- # Ultimately we probably want something like this
167- # packer = build_packing_strategy(packing_config)
168- # dataset = build_dataset(dataset_config)
169- # dataloader = build_dataloader(dataloader_config, dataset, packer)
170166 return dataloader
171167
172168 def forward_backward (
@@ -206,7 +202,6 @@ def forward_backward(
206202 )
207203
208204 # accumulate losses across pipeline microbatches
209- # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
210205 loss = (
211206 torch .mean (torch .stack (losses )).to (self .device )
212207 if self .pp_has_last_stage
@@ -225,27 +220,125 @@ def forward_backward(
225220
226221 return loss
227222
    def forward_only(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass only for evaluation (no backward).

        Mirrors forward_backward(): identical context-parallel and
        pipeline-parallel handling, but no gradients are produced here.
        NOTE(review): this does not enter torch.no_grad() itself — the caller
        (evaluate()) wraps the call in no_grad; confirm no other caller relies
        on this running gradient-free.

        Args:
            input_dict: batch dict; only the "tokens" entry is read here.
            labels: target token ids aligned with input_dict["tokens"].

        Returns:
            Scalar loss tensor on self.device. Under pipeline parallelism,
            stages other than the last return a placeholder tensor [-1.0].
        """
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        inputs = input_dict["tokens"]
        # Context parallelism shards the sequence dimension across the "cp"
        # mesh: tokens/labels on dim 1, each model part's freqs_cis on dim 0.
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward only
            with self.train_context(optional_context_parallel_ctx):
                # Only the last stage holds targets and collects the
                # per-microbatch losses; other stages pass None through.
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            # Mean over microbatch losses on the last stage; other stages
            # report a sentinel of -1.0 so the return type stays uniform.
            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward only
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs)
                    loss = self.loss_fn(pred, labels)
                    # Free the logits promptly — they can dominate memory.
                    del pred

        return loss
273+
228274 def train_step (self , batch ) -> None :
229- # TODO
230- # with GradientAccumulation(
231- # self.gradient_accumulation_steps,
232- # self.model,
233- # self.data_parallel_size,
234- # ) as grad_acc:
235275 labels = batch .pop ("labels" )
236276 loss = self .forward_backward (batch , labels )
237277
238278 logger .info (f"{ self .current_step } / { self .num_training_steps } |Loss: { loss } " )
239- # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
240- # self.pbar.update(1)
241279 self .optimizers .step ()
242280 self .lr_schedulers .step ()
243281
282+ async def evaluate (self ) -> dict [str , float ]:
283+ """Run evaluation on validation set (internal method, not an endpoint)."""
284+ logger .info ("=" * 50 )
285+ logger .info ("STARTING EVALUATION " )
286+ logger .info ("=" * 50 )
287+
288+ # Set model to eval mode
289+ for model_part in self .model_parts :
290+ model_part .eval ()
291+
292+ val_dataloader = iter (self .val_dataloader )
293+ total_loss = 0.0
294+ num_batches = 0
295+
296+ with torch .no_grad ():
297+ for step in range (self .eval_steps ):
298+ try :
299+ batch = next (val_dataloader )
300+
301+ # Move tensors to device
302+ for k , v in batch .items ():
303+ if isinstance (v , torch .Tensor ):
304+ batch [k ] = v .to (self .device )
305+
306+ labels = batch .pop ("labels" )
307+ loss = self .forward_only (batch , labels )
308+
309+ total_loss += loss .item ()
310+ num_batches += 1
311+
312+ logger .info (
313+ f" Eval batch { num_batches } /{ self .eval_steps } | Loss: { loss .item ():.4f} "
314+ )
315+
316+ except StopIteration :
317+ logger .warning ("Reached end of validation dataloader early" )
318+ break
319+
320+ # Set model back to train mode
321+ for model_part in self .model_parts :
322+ model_part .train ()
323+
324+ avg_loss = total_loss / max (num_batches , 1 )
325+
326+ metrics = {
327+ "val_loss" : avg_loss ,
328+ "val_batches" : num_batches ,
329+ }
330+
331+ logger .info ("-" * 50 )
332+ logger .info (f"EVALUATION COMPLETE" )
333+ logger .info (f"Validation Loss: { avg_loss :.4f} " )
334+ logger .info (f"Batches Evaluated: { num_batches } " )
335+ logger .info ("=" * 50 )
336+ return metrics
337+
244338 @endpoint
245339 async def train (self ) -> None :
246340 dataloader = iter (self .train_dataloader )
247341 self .optimizers .zero_grad ()
248-
249342 # TODO: tqdm is broken in Monarch actors
250343 # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)
251344
@@ -254,18 +347,21 @@ async def train(self) -> None:
254347 # Move tensors to the appropriate device
255348 for k , v in batch .items ():
256349 if isinstance (v , torch .Tensor ):
257- batch [k ] = v .to ("cuda" ) # TODO: hardcoded for now
350+ batch [k ] = v .to (self . device ) # TODO: hardcoded for now
258351 self .train_step (batch )
259- # self.profiler.step()
260352 self .current_step += 1
261353
354+ # Run evaluation periodically
355+ if self .current_step % self .eval_interval == 0 :
356+ eval_metrics = await self .evaluate ()
357+ logger .info (f"Step { self .current_step } | Eval metrics: { eval_metrics } " )
358+
359+ # Save checkpoints
262360 self .checkpointer .save (
263361 curr_step = self .current_step ,
264362 last_step = self .current_step == self .num_training_steps ,
265363 )
266364
267- # self.pbar.close()
268-
269365 @endpoint
270366 async def cleanup (self ) -> None :
271367 if self .checkpointer :
0 commit comments