@@ -44,6 +44,9 @@ def exists(val):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def cycle(dataloader: DataLoader):
     while True:
         for batch in dataloader:
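For context, a minimal standalone sketch of how these two helpers behave (the toy `TensorDataset` below is purely illustrative): `cycle` turns a finite `DataLoader` into an endless batch iterator so training can run by step count, and `divisible_by` gates periodic work such as the validation added further down.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def divisible_by(num, den):
    return (num % den) == 0

def cycle(dataloader):
    # restart the dataloader forever so iteration never runs dry
    while True:
        for batch in dataloader:
            yield batch

# toy data, only for illustration
dl = cycle(DataLoader(TensorDataset(torch.randn(10, 4)), batch_size = 2))

for step in range(1, 11):
    (batch,) = next(dl)          # never raises StopIteration
    if divisible_by(step, 5):    # e.g. run validation every 5 steps
        print(f'step {step}: periodic work here')
```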
@@ -74,6 +77,8 @@ def __init__(
         num_train_steps: int,
         batch_size: int,
         grad_accum_every: int = 1,
+        valid_dataset: Dataset | None = None,
+        valid_every: int = 1000,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         ema_decay = 0.999,
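With the two new keyword arguments, constructing the trainer might look like the sketch below. Only the keyword names visible in this hunk (plus `dataset`, which the `__init__` body wraps in a `DataLoader`) are taken from the diff; the `model` object, the train/valid `Dataset` instances, and the positional argument order are assumptions for illustration.

```python
# hypothetical model and Dataset objects; only the keyword names shown
# in this hunk are known to exist on the Trainer
trainer = Trainer(
    model,
    dataset = train_dataset,
    num_train_steps = 10_000,
    batch_size = 4,
    grad_accum_every = 2,
    valid_dataset = valid_dataset,   # new: enables periodic validation on the EMA model
    valid_every = 1000               # new: validate every 1000 optimizer steps
)

trainer()
```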
@@ -122,10 +127,22 @@ def __init__(
 
         self.optimizer = optimizer
 
-        # data
+        # train dataloader
 
         self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
 
+        # validation dataloader on the EMA model
+
+        self.valid_every = valid_every
+
+        self.needs_valid = exists(valid_dataset)
+
+        if self.needs_valid and self.is_main:
+            self.valid_dataset_size = len(valid_dataset)
+            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)
+
+        # training steps and num gradient accum steps
+
         self.num_train_steps = num_train_steps
         self.grad_accum_every = grad_accum_every
 
@@ -154,6 +171,9 @@ def __init__(
     def is_main(self):
         return self.fabric.global_rank == 0
 
+    def wait(self):
+        self.fabric.barrier()
+
     def print(self, *args, **kwargs):
         self.fabric.print(*args, **kwargs)
 
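`wait` is a thin wrapper over Lightning Fabric's `barrier()`. A minimal sketch of the synchronization pattern it enables in the training loop (all ranks sync, only the main rank mutates shared state, then all ranks sync again); the single-device CPU `Fabric` setup here is just for illustration:

```python
from lightning.fabric import Fabric

fabric = Fabric(accelerator = 'cpu', devices = 1)
fabric.launch()

def wait():
    fabric.barrier()

# same rank-0-only pattern used around the EMA update in __call__
wait()                           # every rank reaches this point

if fabric.global_rank == 0:
    print('only the main process updates the EMA weights')

wait()                           # no rank proceeds until rank 0 is done
```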
@@ -165,35 +185,88 @@ def __call__(
     ):
         dl = cycle(self.dataloader)
 
+        # while less than required number of training steps
+
         while self.steps < self.num_train_steps:
 
+            self.model.train()
+
+            # gradient accumulation
+
             for grad_accum_step in range(self.grad_accum_every):
                 is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
 
                 inputs = next(dl)
 
                 with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):
+
+                    # model forwards
+
                     loss, loss_breakdown = self.model(
                         **inputs,
                         return_loss_breakdown = True
                     )
 
+                    # backwards
+
                     self.fabric.backward(loss / self.grad_accum_every)
 
+                # log entire loss breakdown
+
                 self.log(**loss_breakdown._asdict())
 
                 self.print(f'loss: {loss.item():.3f}')
 
+            # clip gradients
+
             self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)
 
+            # optimizer step
+
             self.optimizer.step()
 
+            # update exponential moving average
+
+            self.wait()
+
             if self.is_main:
                 self.ema_model.update()
 
+            self.wait()
+
+            # scheduler
+
             self.scheduler.step()
             self.optimizer.zero_grad()
 
             self.steps += 1
 
+            # maybe validate, for now, only on main with EMA model
+
+            if (
+                self.is_main and
+                self.needs_valid and
+                divisible_by(self.steps, self.valid_every)
+            ):
+                with torch.no_grad():
+                    self.ema_model.eval()
+
+                    total_valid_loss = 0.
+
+                    for valid_batch in self.valid_dataloader:
+                        valid_loss, valid_loss_breakdown = self.ema_model(
+                            **valid_batch,
+                            return_loss_breakdown = True
+                        )
+
+                        valid_batch_size = valid_batch.get('atom_inputs').shape[0]
+                        scale = valid_batch_size / self.valid_dataset_size
+
+                        scaled_valid_loss = valid_loss.item() * scale
+                        total_valid_loss += scaled_valid_loss
+
+                    self.print(f'valid loss: {total_valid_loss:.3f}')
+
+            self.wait()
+
         print(f'training complete')
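On the validation reduction: assuming each `valid_loss` is already a mean over its batch (which the `valid_batch_size / self.valid_dataset_size` weighting implies), summing the scaled per-batch losses yields the per-sample mean over the whole validation set, even when the final batch is smaller (the validation `DataLoader` is built without `drop_last`). A small numeric sketch with made-up losses:

```python
# toy per-batch mean losses and batch sizes; the last batch is smaller
batch_losses = [0.9, 0.7, 0.5]
batch_sizes  = [4, 4, 2]
dataset_size = sum(batch_sizes)          # 10 samples total

total_valid_loss = 0.
for loss, size in zip(batch_losses, batch_sizes):
    total_valid_loss += loss * (size / dataset_size)

# identical to averaging over all 10 samples at once
assert abs(total_valid_loss - (0.9 * 4 + 0.7 * 4 + 0.5 * 2) / dataset_size) < 1e-9

print(f'valid loss: {total_valid_loss:.3f}')   # 0.740
```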