
Commit 0d18523

fix: matmul
1 parent d8853d1 commit 0d18523

File tree

1 file changed: +12 -12 lines changed

pytorch_optimizer/optimizer/galore.py

Lines changed: 12 additions & 12 deletions
@@ -119,21 +119,21 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
             ) * self.scale
         if self.projection_type == 'reverse_std':
             return (
-                torch.matmul(self.ortho_matrix, low_rank_grad)
+                torch.matmul(self.ortho_matrix, low_rank_grad.t())
                 if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
-                else torch.matmul(low_rank_grad, self.ortho_matrix)
+                else torch.matmul(low_rank_grad, self.ortho_matrix.t())
             ) * self.scale
         if self.projection_type == 'right':
-            return torch.matmul(low_rank_grad, self.ortho_matrix) * self.scale
+            return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
         if self.projection_type == 'left':
-            return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
+            return torch.matmul(self.ortho_matrix, low_rank_grad.t()) * self.scale
         if self.projection_type == 'full':
-            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] * self.scale
+            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale

         raise NotImplementedError


-class GaLoreOptimizer(Optimizer, BaseOptimizer):
+class GaLore(Optimizer, BaseOptimizer):
     r"""AdamW optimizer with GaLore projector.

     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
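The added transposes matter because the cached factor is not square. A minimal shape sketch (not the library's code), assuming the 'right' projector stores an orthonormal basis of shape (n, rank): projecting back a (m, rank) gradient only type-checks against the transposed basis.

```python
# Hedged shape sketch: why project_back needs ortho_matrix.t(),
# assuming the right factor is stored as (n, rank) with orthonormal columns.
import torch

m, n, rank = 64, 128, 16
grad = torch.randn(m, n)
ortho, _ = torch.linalg.qr(torch.randn(n, rank))  # (n, rank)

low_rank = torch.matmul(grad, ortho)          # project: (m, n) @ (n, rank) -> (m, rank)
restored = torch.matmul(low_rank, ortho.t())  # back: (m, rank) @ (rank, n) -> (m, n)
assert restored.shape == grad.shape
# without .t(): (m, rank) @ (n, rank) raises a shape-mismatch RuntimeError
```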
@@ -209,7 +209,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]

-                if 'rank' in group:
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                if 'rank' in group and p.dim() > 1:
                     if 'projector' not in state:
                         state['projector'] = GaLoreProjector(
                             rank=group['rank'],
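The new `p.dim() > 1` condition keeps 1-D tensors out of the projector. A hedged sketch of that predicate (names are illustrative, not the library's internals):

```python
# Illustrative guard: only matrix-shaped parameters go through GaLore
# projection; 1-D tensors (biases, norm weights) take the plain AdamW path.
import torch

def wants_galore(p: torch.Tensor, group: dict) -> bool:
    return 'rank' in group and p.dim() > 1

group = {'rank': 16}
weight = torch.randn(128, 64)
bias = torch.randn(64)

assert wants_galore(weight, group)    # projected, low-rank statistics
assert not wants_galore(bias, group)  # too few dims: regular AdamW update
```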
@@ -220,10 +224,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                     grad = state['projector'].project(grad, group['step'])

-                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_sq'] = torch.zeros_like(p)
-
                 self.apply_weight_decay(
                     p=p,
                     grad=None,
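Moving the moment-buffer initialization above the projector branch fixes an ordering bug visible in the old hunk: the projector was cached in `state` first, so `len(state) == 0` was already False on the very first step and `exp_avg`/`exp_avg_sq` were never created. An illustrative repro of that stale-check pattern (plain dict, not the optimizer's actual state):

```python
# Why the init moved: once anything is cached in state, the emptiness
# check can no longer detect the first step.
state = {}
state['projector'] = object()  # old order: projector stored first
if len(state) == 0:            # already False on the first step
    state['exp_avg'] = 0.0     # unreachable -> later KeyError on state['exp_avg']
assert 'exp_avg' not in state
```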
@@ -241,7 +241,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 norm_grad = exp_avg / de_nom

-                if 'rank' in group:
+                if 'rank' in group and p.dim() > 1:
                     norm_grad = state['projector'].project_back(norm_grad)

                 p.add_(norm_grad, alpha=-step_size)
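A hypothetical usage sketch for the renamed class. The top-level import assumes the class is exported from the package root like the repo's other optimizers, and everything beyond `params` and `lr` is an assumption rather than the actual signature:

```python
import torch
from pytorch_optimizer import GaLore  # formerly GaLoreOptimizer (assumed export path)

model = torch.nn.Linear(128, 64)
optimizer = GaLore(model.parameters(), lr=1e-3)  # per-group 'rank' is optional per the diff above

loss = model(torch.randn(8, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```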
