@@ -144,24 +144,26 @@ def __init__(
144144 self .total_iterations : int = num_epochs * num_batches_per_epoch
145145 if not self .total_iterations :
146146 raise ValueError (
147- 'missing total iterations, which is calculated from num epochs and num iterations per epoch param'
147+ 'missing total iterations, '
148+ 'calculated from num epochs and num iterations per epoch param'
148149 )
149150
150- # lr
151151 self .starting_lr = lr
152152 self .current_lr = lr
153153
154- # warmup - we'll use default recommended in Ma/Yarats unless user specifies num iterations
154+ # warmup - we'll use default recommended in Ma/Yarats
155+ # unless user specifies num iterations
155156 self .use_warmup = use_warmup
156- self .warmup_complete = False
157157 self .warmup_type = warmup_type
158158 self .warmup_pct_default = warmup_pct_default
159+ self .warmup_complete : bool = False
159160
160161 if num_warmup_iterations is None :
161162 beta_warmup_iterations : int = math .ceil ((2 / (1 - betas [1 ])))
162163 beta_pct : float = beta_warmup_iterations / self .total_iterations
163164
164- # this can be unreasonable for short runs...so let's compare vs warmup pct % of total epochs
165+ # this can be unreasonable for short runs...
166+ # so let's compare vs warmup_pct_default of total iterations
165167 if beta_pct > 0.45 :
166168 warmup_auto_pct = int (
167169 self .warmup_pct_default * self .total_iterations
@@ -351,7 +353,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
351353 loss = closure ()
352354
353355 param_size : float = 0
354- variance_ma_sum : float = 0 .0
356+ variance_ma_sum : float = 1 .0
355357
356358 # phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay
357359 for i , group in enumerate (self .param_groups ):
0 commit comments