File tree: 2 files changed, +6 -4 lines changed
pytorch_optimizer/optimizer: 2 files changed, +6 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -145,8 +145,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
145145 state ['exp_avg' ] = torch .zeros_like (p )
146146
147147 if factored :
148- state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ]).to (grad )
149- state ['exp_avg_sq_col' ] = torch .zeros (grad_shape [:- 2 ] + grad_shape [- 1 :]).to (grad )
148+ state ['exp_avg_sq_row' ] = torch .zeros (grad_shape [:- 1 ], dtype = grad .dtype , device = grad .device )
149+ state ['exp_avg_sq_col' ] = torch .zeros (
150+ grad_shape [:- 2 ] + grad_shape [- 1 :], dtype = grad .dtype , device = grad .device
151+ )
150152 else :
151153 state ['exp_avg_sq' ] = torch .zeros_like (grad )
152154
Original file line number | Diff line number | Diff line change
@@ -98,13 +98,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9898 state ['momentum_buffer' ] = torch .zeros_like (p )
9999
100100 if grad .is_sparse :
101- state ['accumulator_0' ] = torch .zeros (shape [0 ], device = grad .device )
101+ state ['accumulator_0' ] = torch .zeros (shape [0 ], dtype = grad . dtype , device = grad .device )
102102 elif rank == 0 :
103103 state ['accumulator_0' ] = torch .zeros_like (p )
104104 else :
105105 for i in range (rank ):
106106 state [f'accumulator_{ i } ' ] = torch .zeros (
107- [1 ] * i + [shape [i ]] + [1 ] * (rank - 1 - i ), device = grad .device
107+ [1 ] * i + [shape [i ]] + [1 ] * (rank - 1 - i ), dtype = grad . dtype , device = grad .device
108108 )
109109
110110 state ['step' ] += 1
You can’t perform that action at this time.
0 commit comments