@@ -52,6 +52,20 @@ def cycle(dataloader: DataLoader):
         for batch in dataloader:
             yield batch

+@typecheck
+def accum_dict(
+    past_losses: dict | None,
+    losses: dict,
+    scale: float = 1.
+):
+    if not exists(past_losses):
+        return losses
+
+    for loss_name in past_losses.keys():
+        past_losses[loss_name] += losses.get(loss_name, 0.) * scale
+
+    return past_losses
+
 def default_lambda_lr_fn(steps):
     # 1000 step warmup

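For reference, a minimal standalone sketch (not part of the commit) of how the new accum_dict helper combines loss-breakdown dicts; the loss names and values below are hypothetical, and exists is assumed to be the module's usual None check:

    def exists(v):
        return v is not None

    # decorator-free copy of the accum_dict added above
    def accum_dict(past_losses, losses, scale = 1.):
        if not exists(past_losses):
            return losses

        for loss_name in past_losses.keys():
            past_losses[loss_name] += losses.get(loss_name, 0.) * scale

        return past_losses

    # first call returns the breakdown as-is; later calls add each entry scaled by `scale`
    breakdown = None
    breakdown = accum_dict(breakdown, {'diffusion': 1.0, 'distogram': 0.4})
    breakdown = accum_dict(breakdown, {'diffusion': 0.8, 'distogram': 0.2}, scale = 0.5)
    # breakdown == {'diffusion': 1.4, 'distogram': 0.5}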
@@ -193,6 +207,9 @@ def __call__(

             # gradient accumulation

+            total_loss = 0.
+            train_loss_breakdown = None
+
             for grad_accum_step in range(self.grad_accum_every):
                 is_accumulating = grad_accum_step < (self.grad_accum_every - 1)

@@ -207,15 +224,22 @@ def __call__(
                         return_loss_breakdown = True
                     )

+                    # accumulate
+
+                    scale = self.grad_accum_every ** -1
+
+                    total_loss += loss.item() * scale
+                    train_loss_breakdown = accum_dict(train_loss_breakdown, loss_breakdown._asdict(), scale = scale)
+
                     # backwards

                     self.fabric.backward(loss / self.grad_accum_every)

             # log entire loss breakdown

-            self.log(**loss_breakdown._asdict())
+            self.log(**train_loss_breakdown)

-            self.print(f'loss: {loss.item():.3f}')
+            self.print(f'loss: {total_loss:.3f}')

             # clip gradients

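The scale of self.grad_accum_every ** -1 turns the accumulated total_loss into a mean over the accumulation window; a tiny sketch with hypothetical per-step losses, assuming grad_accum_every = 4:

    step_losses = [2.0, 1.6, 1.8, 1.4]   # hypothetical per-step losses
    scale = 1 / 4                        # self.grad_accum_every ** -1 with grad_accum_every = 4
    total_loss = sum(l * scale for l in step_losses)
    # total_loss == 1.7, the mean loss over the four accumulation steps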
@@ -252,21 +276,30 @@ def __call__(
                 self.ema_model.eval()

                 total_valid_loss = 0.
+                valid_loss_breakdown = None

                 for valid_batch in self.valid_dataloader:
-                    valid_loss, valid_loss_breakdown = self.ema_model(
+                    valid_loss, loss_breakdown = self.ema_model(
                         **valid_batch,
                         return_loss_breakdown = True
                     )

                     valid_batch_size = valid_batch.get('atom_inputs').shape[0]
                     scale = valid_batch_size / self.valid_dataset_size

-                    scaled_valid_loss = valid_loss.item() * scale
-                    total_valid_loss += scaled_valid_loss
+                    total_valid_loss += valid_loss.item() * scale
+                    valid_loss_breakdown = accum_dict(valid_loss_breakdown, loss_breakdown._asdict(), scale = scale)

                 self.print(f'valid loss: {total_valid_loss:.3f}')

+                # prepend valid_ to all losses for logging
+
+                valid_loss_breakdown = {f'valid_{k}':v for k, v in valid_loss_breakdown.items()}
+
+                # log
+
+                self.log(**valid_loss_breakdown)
+
             self.wait()

         print(f'training complete')
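In the validation loop, the per-batch scale of valid_batch_size / self.valid_dataset_size makes total_valid_loss (and the accumulated valid_loss_breakdown) an example-weighted mean over the whole validation set; a sketch with hypothetical batch sizes and losses, assuming a valid_dataset_size of 10:

    valid_dataset_size = 10
    batch_sizes  = [4, 4, 2]            # hypothetical
    batch_losses = [1.0, 0.5, 2.0]      # hypothetical
    total_valid_loss = sum(l * (b / valid_dataset_size) for l, b in zip(batch_losses, batch_sizes))
    # total_valid_loss == 1.0, each example weighted equally regardless of batch size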