@@ -119,21 +119,21 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
         ) * self.scale
         if self.projection_type == 'reverse_std':
             return (
-                torch.matmul(self.ortho_matrix, low_rank_grad)
+                torch.matmul(self.ortho_matrix, low_rank_grad.t())
                 if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
-                else torch.matmul(low_rank_grad, self.ortho_matrix)
+                else torch.matmul(low_rank_grad, self.ortho_matrix.t())
             ) * self.scale
         if self.projection_type == 'right':
-            return torch.matmul(low_rank_grad, self.ortho_matrix) * self.scale
+            return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
         if self.projection_type == 'left':
-            return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
+            return torch.matmul(self.ortho_matrix, low_rank_grad.t()) * self.scale
         if self.projection_type == 'full':
-            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] * self.scale
+            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale
 
         raise NotImplementedError
 
 
-class GaLoreOptimizer(Optimizer, BaseOptimizer):
+class GaLore(Optimizer, BaseOptimizer):
     r"""AdamW optimizer with GaLore projector.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
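
The added `.t()` calls restore the shape contract between `project` and `project_back`. Below is a minimal, self-contained sketch of that round trip for the 'right' projection; the (n, rank) storage convention for the right factor, and the toy sizes, are assumptions for illustration, not the library's code:

import torch

m, n, rank = 8, 32, 4
grad = torch.randn(m, n)

# Hypothetical right factor: top-`rank` right singular vectors, stored as (n, rank).
_, _, vh = torch.linalg.svd(grad, full_matrices=False)
ortho = vh[:rank, :].t()          # (n, rank)

low_rank = grad @ ortho           # project:      (m, n) @ (n, rank) -> (m, rank)
restored = low_rank @ ortho.t()   # project_back: (m, rank) @ (rank, n) -> (m, n)

assert restored.shape == grad.shape  # without the .t(), the second matmul cannot line up
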
@@ -209,7 +209,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
 
-                if 'rank' in group:
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                if 'rank' in group and p.dim() > 1:
                     if 'projector' not in state:
                         state['projector'] = GaLoreProjector(
                             rank=group['rank'],
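
Two things change here: the moment buffers are initialized before any projection, and the projector is gated on `p.dim() > 1`, so 1-D tensors (biases, norm scales) fall through to the plain AdamW path. A sketch of the gating predicate in isolation (the helper name is illustrative):

import torch

def uses_galore(group: dict, p: torch.Tensor) -> bool:
    # GaLore's SVD-based projection only makes sense for matrix-shaped parameters.
    return 'rank' in group and p.dim() > 1

weight, bias = torch.randn(64, 32), torch.randn(64)
assert uses_galore({'rank': 8}, weight)    # 2-D weight -> projected path
assert not uses_galore({'rank': 8}, bias)  # 1-D bias   -> plain AdamW path
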
@@ -220,10 +224,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                     grad = state['projector'].project(grad, group['step'])
 
-                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_sq'] = torch.zeros_like(p)
-
                 self.apply_weight_decay(
                     p=p,
                     grad=None,
@@ -241,7 +241,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 norm_grad = exp_avg / de_nom
 
-                if 'rank' in group:
+                if 'rank' in group and p.dim() > 1:
                     norm_grad = state['projector'].project_back(norm_grad)
 
                 p.add_(norm_grad, alpha=-step_size)
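
With the class renamed from `GaLoreOptimizer` to `GaLore`, a typical setup looks like the sketch below. The import path, and every group key besides `rank` (which the diff reads from the group), are assumptions based on the `GaLoreProjector` arguments visible above:

import torch
from torch import nn
from pytorch_optimizer import GaLore  # assumed import path

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

# Matching the p.dim() > 1 guard above: only matrix-shaped weights carry 'rank'.
matrix_params = [p for p in model.parameters() if p.dim() > 1]
other_params = [p for p in model.parameters() if p.dim() <= 1]

optimizer = GaLore(
    [
        # 'update_proj_gap', 'scale', and 'projection_type' mirror GaLoreProjector's
        # parameters and are assumed to be passed through the param group.
        {'params': matrix_params, 'rank': 64, 'update_proj_gap': 200, 'scale': 0.25, 'projection_type': 'std'},
        {'params': other_params},
    ],
    lr=1e-3,
)
# ...followed by the usual loss.backward(); optimizer.step() loop.
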