@@ -80,7 +80,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8080 loss = closure ()
8181
8282 for group in self .param_groups :
83- momentum = group ['momentum' ]
83+ momentum , weight_decay = group ['momentum' ], group [ 'weight_decay ' ]
8484 for p in group ['params' ]:
8585 if p .grad is None :
8686 continue
@@ -100,29 +100,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
100100 state [f'pre_cond_{ dim_id } ' ] = self .matrix_eps * torch .eye (dim , out = grad .new (dim , dim ))
101101 state [f'inv_pre_cond_{ dim_id } ' ] = grad .new (dim , dim ).zero_ ()
102102
103- state ['step' ] += 1
104-
105103 if momentum > 0.0 :
106104 grad .mul_ (1.0 - momentum ).add_ (state ['momentum_buffer' ], alpha = momentum )
107105
108- if group [ ' weight_decay' ] > 0.0 :
109- grad .add_ (p , alpha = group [ ' weight_decay' ] )
106+ if weight_decay > 0.0 :
107+ grad .add_ (p , alpha = weight_decay )
110108
111109 order : int = grad .ndimension ()
112110 original_size : int = grad .size ()
113111 for dim_id , dim in enumerate (grad .size ()):
114- pre_cond = state [f'pre_cond_{ dim_id } ' ]
115- inv_pre_cond = state [f'inv_pre_cond_{ dim_id } ' ]
112+ pre_cond , inv_pre_cond = state [f'pre_cond_{ dim_id } ' ], state [f'inv_pre_cond_{ dim_id } ' ]
116113
117114 grad = grad .transpose_ (0 , dim_id ).contiguous ()
118115 transposed_size = grad .size ()
119116
120117 grad = grad .view (dim , - 1 )
121-
122118 grad_t = grad .t ()
119+
123120 pre_cond .add_ (grad @ grad_t )
124121 if state ['step' ] % self .preconditioning_compute_steps == 0 :
125- inv_pre_cond . copy_ ( compute_power_svd (pre_cond , - 1.0 / order ) )
122+ inv_pre_cond = compute_power_svd (pre_cond , - 1.0 / order )
126123
127124 if dim_id == order - 1 :
128125 grad = grad_t @ inv_pre_cond
@@ -131,6 +128,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
131128 grad = inv_pre_cond @ grad
132129 grad = grad .view (transposed_size )
133130
131+ state ['step' ] += 1
134132 state ['momentum_buffer' ] = grad
135133
136134 p .add_ (grad , alpha = - group ['lr' ])
0 commit comments