@@ -81,8 +81,8 @@ def reset(self):
8181 state ['exp_avg' ] = torch .zeros_like (p )
8282
8383 if factored :
84- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad . dtype )
85- state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad . dtype )
84+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ]). to ( grad )
85+ state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :]). to ( grad )
8686 else :
8787 state ['exp_avg_sq' ] = torch .zeros_like (grad )
8888
@@ -114,8 +114,8 @@ def approximate_sq_grad(
114114 exp_avg_sq_col : torch .Tensor ,
115115 output : torch .Tensor ,
116116 ):
117- r"""Get approximate squared gradient."""
118- r_factor : torch .Tensor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 )).rsqrt_ ().unsqueeze (- 1 )
117+ r"""Get approximation of EMA of squared gradient."""
118+ r_factor : torch .Tensor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 , keepdim = True )).rsqrt_ ().unsqueeze (- 1 )
119119 c_factor : torch .Tensor = exp_avg_sq_col .unsqueeze (- 2 ).rsqrt ()
120120 torch .mul (r_factor , c_factor , out = output )
121121
@@ -145,8 +145,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
145145 state ['exp_avg' ] = torch .zeros_like (p )
146146
147147 if factored :
148- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype )
149- state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype )
148+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
149+ state ['exp_avg_sq_col' ] = torch .zeros (
150+ grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
151+ )
150152 else :
151153 state ['exp_avg_sq' ] = torch .zeros_like (grad )
152154
@@ -166,11 +168,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
166168 exp_avg_sq_row .mul_ (beta2_t ).add_ (update .mean (dim = - 1 ), alpha = 1.0 - beta2_t )
167169 exp_avg_sq_col .mul_ (beta2_t ).add_ (update .mean (dim = - 2 ), alpha = 1.0 - beta2_t )
168170
169- # Approximation of exponential moving average of square of gradient
170- self .approximate_sq_grad (exp_avg_sq_row , exp_avg_sq_col , update )
171+ self .approximate_sq_grad (exp_avg_sq_row , exp_avg_sq_col , output = update )
171172 else :
172173 exp_avg_sq = state ['exp_avg_sq' ]
173-
174174 exp_avg_sq .mul_ (beta2_t ).add_ (update , alpha = 1.0 - beta2_t )
175175 torch .rsqrt (exp_avg_sq , out = update )
176176
0 commit comments