@@ -97,6 +97,7 @@ def __init__(
         grad_accum_every: int = 1,
         valid_dataset: Dataset | None = None,
         valid_every: int = 1000,
+        test_dataset: Dataset | None = None,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         ema_decay = 0.999,
@@ -159,6 +160,14 @@ def __init__(
             self.valid_dataset_size = len(valid_dataset)
             self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)
 
+        # testing dataloader on EMA model
+
+        self.needs_test = exists(test_dataset)
+
+        if self.needs_test and self.is_main:
+            self.test_dataset_size = len(test_dataset)
+            self.test_dataloader = DataLoader(test_dataset, batch_size = batch_size)
+
         # training steps and num gradient accum steps
 
         self.num_train_steps = num_train_steps
@@ -347,4 +356,35 @@ def __call__(
 
         self.wait()
 
+        # maybe test
+
+        if self.is_main and self.needs_test:
+            with torch.no_grad():
+                self.ema_model.eval()
+
+                total_test_loss = 0.
+                test_loss_breakdown = None
+
+                for test_batch in self.test_dataloader:
+                    test_loss, loss_breakdown = self.ema_model(
+                        **test_batch,
+                        return_loss_breakdown = True
+                    )
+
+                    test_batch_size = test_batch.get('atom_inputs').shape[0]
+                    scale = test_batch_size / self.test_dataset_size
+
+                    total_test_loss += test_loss.item() * scale
+                    test_loss_breakdown = accum_dict(test_loss_breakdown, loss_breakdown._asdict(), scale = scale)
+
+                self.print(f'test loss: {total_test_loss:.3f}')
+
+                # prepend test_ to all losses for logging
+
+                test_loss_breakdown = {f'test_{k}': v for k, v in test_loss_breakdown.items()}
+
+                # log
+
+                self.log(**test_loss_breakdown)
+
         print(f'training complete')
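
For context, a minimal usage sketch of the new argument follows. Only the keyword names visible in the diff above (test_dataset, valid_dataset, valid_every, num_train_steps, batch_size) are taken from the source; the Trainer class name and the model / dataset objects are assumptions for illustration.

# hypothetical usage sketch - names not shown in the diff above are assumed
trainer = Trainer(
    model,
    dataset = train_dataset,        # assumed keyword for the training split
    valid_dataset = valid_dataset,  # evaluated on the EMA model every `valid_every` steps
    test_dataset = test_dataset,    # new: evaluated once on the EMA model after training
    num_train_steps = 10_000,
    batch_size = 4,
    valid_every = 1000,
)

trainer()  # after the final step, the test loop above runs and logs the `test_*` losses

Note the weighting in the test loop: each batch's loss is scaled by batch_size / test_dataset_size before accumulation, so total_test_loss is the mean loss over the whole test set (and is exact for a ragged final batch) rather than a sum of per-batch losses.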