@@ -178,6 +178,12 @@ def main(
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
+    # Final evaluation
+    val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+    fabric.log_dict(metrics)
+    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+
     # Save the final Adapter checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth.adapter_v2"
     save_path.parent.mkdir(parents=True, exist_ok=True)
@@ -211,7 +217,7 @@ def fit(
211217 f" { model .max_seq_length } and context length is { model .config .block_size } "
212218 )
213219
214- validate (fabric , model , val_dataloader , tokenizer , dataclasses .replace (eval , max_iters = 2 ), data ) # sanity check
220+ validate (fabric , model , val_dataloader , dataclasses .replace (eval , max_iters = 2 )) # sanity check
215221
216222 train_iterator = CycleIterator (train_dataloader )
217223 throughput = ThroughputMonitor (fabric , window_size = 50 )
@@ -278,7 +284,8 @@ def fit(
 
         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
-            val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
+            val_loss = validate(fabric, model, val_dataloader, eval)
+            generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
@@ -295,11 +302,8 @@ def fit(
         save_prompt_style(data.prompt_style, checkpoint_file.parent)
 
 
-# the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
-def validate(
-    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs) -> torch.Tensor:
     fabric.print("Validating ...")
     model.eval()
     losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
@@ -311,25 +315,30 @@ def validate(
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
 
     val_loss = losses.mean()
+    model.train()
+    return val_loss
 
-    # produce an example:
+
+# the adapter "kv cache" cannot be initialized under `inference_mode`
+@torch.no_grad()
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
+    model.eval()
+
     with fabric.init_tensor():
         # do not set `max_seq_length=max_returned_token` because memory is not a concern here
         model.set_kv_cache(batch_size=1)
     output = generate(
         model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
     )
     model.clear_kv_cache()
+    model.train()
     output = tokenizer.decode(output)
     fabric.print(output)
 
-    model.train()
-    return val_loss
-
 
 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing