Skip to content

Commit cc2891e

Browse files
committed
fix: qr decompose for bfloat16
1 parent bd9e1e6 commit cc2891e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/optimizer/racs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,15 @@ def subspace_iteration(
211211
return torch.linalg.eigh(u.T @ a @ u)
212212

213213
def switch(self, q: torch.Tensor, u_prev: torch.Tensor, rank: int, leading_basis: int) -> torch.Tensor:
214-
vals, vecs = self.subspace_iteration(q, u_prev, num_steps=1)
214+
vals, vecs = self.subspace_iteration(q.to(torch.float32), u_prev.to(torch.float32), num_steps=1)
215215

216216
leading_indices = torch.argsort(vals, descending=True)[:leading_basis]
217217
u_t1 = vecs[:, leading_indices]
218218

219219
u_c, _ = torch.linalg.qr(torch.eye(q.shape[0], device=q.device) - u_t1 @ u_t1.T)
220220
u_t2 = u_c[:, :rank - leading_basis] # fmt: skip
221221

222-
return torch.cat([u_t1, u_t2], dim=1)
222+
return torch.cat([u_t1, u_t2], dim=1).to(q.dtype)
223223

224224
@staticmethod
225225
def compensation(

0 commit comments

Comments
 (0)