@@ -81,8 +81,10 @@ def reset(self):
8181 state ['exp_avg' ] = torch .zeros_like (p )
8282
8383 if factored :
84- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype )
85- state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype )
84+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
85+ state ['exp_avg_sq_col' ] = torch .zeros (
86+ grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
87+ )
8688 else :
8789 state ['exp_avg_sq' ] = torch .zeros_like (grad )
8890
@@ -145,8 +147,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
145147 state ['exp_avg' ] = torch .zeros_like (p )
146148
147149 if factored :
148- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype )
149- state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype )
150+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
151+ state ['exp_avg_sq_col' ] = torch .zeros (
152+ grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
153+ )
150154 else :
151155 state ['exp_avg_sq' ] = torch .zeros_like (grad )
152156
@@ -170,7 +174,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
170174 self .approximate_sq_grad (exp_avg_sq_row , exp_avg_sq_col , update )
171175 else :
172176 exp_avg_sq = state ['exp_avg_sq' ]
173-
174177 exp_avg_sq .mul_ (beta2_t ).add_ (update , alpha = 1.0 - beta2_t )
175178 torch .rsqrt (exp_avg_sq , out = update )
176179
0 commit comments