@@ -146,18 +146,26 @@ def __init__(
146146
147147 self .clip_grad_norm = clip_grad_norm
148148
149+ # steps
150+
151+ self .steps = 0
152+
@property
def is_main(self):
    """Return True when `self.fabric.global_rank` equals 0, False otherwise."""
    rank = self.fabric.global_rank
    return rank == 0
152156
def print(self, *args, **kwargs):
    """Forward every positional and keyword argument to `self.fabric.print`.

    Nothing is returned; whatever `self.fabric.print` returns is discarded.
    """
    # NOTE(review): presumably fabric only emits output on the main process -
    # confirm against the Fabric documentation.
    fabric_print = self.fabric.print
    fabric_print(*args, **kwargs)
159+
def log(self, **log_data):
    """Send the given keyword arguments to `self.fabric.log_dict` as one dict.

    The keyword arguments are collected into `log_data` and logged with
    `step` set to the current value of `self.steps`. Nothing is returned.
    """
    current_step = self.steps
    self.fabric.log_dict(log_data, step=current_step)
162+
153163 def __call__ (
154164 self
155165 ):
156- dl = iter (self .dataloader )
157-
158- steps = 0
166+ dl = cycle (self .dataloader )
159167
160- while steps < self .num_train_steps :
168+ while self . steps < self .num_train_steps :
161169
162170 for grad_accum_step in range (self .grad_accum_every ):
163171 is_accumulating = grad_accum_step < (self .grad_accum_every - 1 )
@@ -169,7 +177,9 @@ def __call__(
169177
170178 self .fabric .backward (loss / self .grad_accum_every )
171179
172- print (f'loss: { loss .item ():.3f} ' )
180+ self .log (loss = loss )
181+
182+ self .print (f'loss: { loss .item ():.3f} ' )
173183
174184 self .fabric .clip_gradients (self .model , self .optimizer , max_norm = self .clip_grad_norm )
175185
@@ -181,6 +191,6 @@ def __call__(
181191 self .scheduler .step ()
182192 self .optimizer .zero_grad ()
183193
184- steps += 1
194+ self . steps += 1
185195
186196 print (f'training complete' )
0 commit comments