@@ -81,10 +81,8 @@ def reset(self):
8181 state ['exp_avg' ] = torch .zeros_like (p )
8282
8383 if factored :
84- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
85- state ['exp_avg_sq_col' ] = torch .zeros (
86- grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
87- )
84+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ]).to (grad )
85+ state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :]).to (grad )
8886 else :
8987 state ['exp_avg_sq' ] = torch .zeros_like (grad )
9088
@@ -116,8 +114,8 @@ def approximate_sq_grad(
116114 exp_avg_sq_col : torch .Tensor ,
117115 output : torch .Tensor ,
118116 ):
119- r"""Get approximate squared gradient."""
120- r_factor : torch .Tensor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 )).rsqrt_ ().unsqueeze (- 1 )
117+ r"""Get approximation of EMA of squared gradient."""
118+ r_factor : torch .Tensor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 , keepdim = True )).rsqrt_ ().unsqueeze (- 1 )
121119 c_factor : torch .Tensor = exp_avg_sq_col .unsqueeze (- 2 ).rsqrt ()
122120 torch .mul (r_factor , c_factor , out = output )
123121
@@ -147,10 +145,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
147145 state ['exp_avg' ] = torch .zeros_like (p )
148146
149147 if factored :
150- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
151- state ['exp_avg_sq_col' ] = torch .zeros (
152- grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
153- )
148+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ]).to (grad )
149+ state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :]).to (grad )
154150 else :
155151 state ['exp_avg_sq' ] = torch .zeros_like (grad )
156152
@@ -170,8 +166,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
170166 exp_avg_sq_row .mul_ (beta2_t ).add_ (update .mean (dim = - 1 ), alpha = 1.0 - beta2_t )
171167 exp_avg_sq_col .mul_ (beta2_t ).add_ (update .mean (dim = - 2 ), alpha = 1.0 - beta2_t )
172168
173- # Approximation of exponential moving average of square of gradient
174- self .approximate_sq_grad (exp_avg_sq_row , exp_avg_sq_col , update )
169+ self .approximate_sq_grad (exp_avg_sq_row , exp_avg_sq_col , output = update )
175170 else :
176171 exp_avg_sq = state ['exp_avg_sq' ]
177172 exp_avg_sq .mul_ (beta2_t ).add_ (update , alpha = 1.0 - beta2_t )
0 commit comments