
Commit aca76b6

kylevedder and kozistr authored
[Feature] Store SOAP condition matrices as the dtype of their parameters (#335)
* Added parameter dtype support throughout the conditioner code
* Update pytorch_optimizer/optimizer/soap.py
* Update pytorch_optimizer/optimizer/soap.py
* Update pytorch_optimizer/optimizer/soap.py

---------

Co-authored-by: Kyle Vedder <[email protected]>
Co-authored-by: Hyeongchan Kim <[email protected]>
1 parent 4b439ab commit aca76b6

File tree

1 file changed: +14 −6 lines changed


pytorch_optimizer/optimizer/soap.py

Lines changed: 14 additions & 6 deletions
@@ -102,7 +102,7 @@ def project(

         for mat in state['Q']:
             if len(mat) > 0:
-                grad = torch.tensordot(grad, mat.to(grad.dtype), dims=[[0], [0 if project_type == 'forward' else 1]])
+                grad = torch.tensordot(grad, mat, dims=[[0], [0 if project_type == 'forward' else 1]])
             else:
                 grad = grad.permute([*list(range(1, len(grad.shape))), 0])

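For context, a minimal standalone sketch (not the library's code; the tensors below are made up) of why the explicit mat.to(grad.dtype) cast becomes redundant once the state['Q'] matrices are stored in the parameter's dtype:

import torch

# Illustrative bf16 gradient and a conditioner matrix stored in the same dtype.
grad = torch.randn(8, 4, dtype=torch.bfloat16)
mat = torch.randn(8, 8, dtype=torch.bfloat16)

# With matching dtypes, no cast is needed before the contraction over dim 0.
projected = torch.tensordot(grad, mat, dims=[[0], [0]])
print(projected.dtype)  # torch.bfloat16
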
@@ -123,9 +123,11 @@ def get_orthogonal_matrix(mat: torch.Tensor) -> List[torch.Tensor]:
                 continue

             try:
-                _, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))
+                _, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
             except Exception:  # pragma: no cover
-                _, q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device))
+                _, q = torch.linalg.eigh(
+                    m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=torch.float64)
+                )
                 q = q.to(m.dtype)

             q = torch.flip(q, dims=[1])
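
As a rough, self-contained illustration of the eigendecomposition path (a sketch, not the optimizer's API): torch.linalg.eigh has no bfloat16/float16 kernel, so a low-precision conditioner matrix falls through to the float64 branch, whose jitter identity is now built in float64 explicitly:

import torch

m = torch.randn(4, 4, dtype=torch.bfloat16)
m = m @ m.T  # symmetric, like the GG accumulators

try:
    # eigh supports float32/float64 (and complex) only, so this raises for bf16 input.
    _, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
except Exception:
    _, q = torch.linalg.eigh(
        m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=torch.float64)
    )
    q = q.to(m.dtype)

print(q.dtype)  # torch.bfloat16
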
@@ -156,7 +158,13 @@ def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, mer

             power_iter = m @ o[:, sort_idx]

-            q, _ = torch.linalg.qr(power_iter)
+            # Compute QR decomposition
+            # We cast to float32 because:
+            #  - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
+            #  - the correctness / numerical stability of the Q orthogonalization is important for the stability
+            #    of the optimizer
+            q, _ = torch.linalg.qr(power_iter.to(torch.float32))
+            q = q.to(power_iter.dtype)

             matrices.append(q)

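A standalone sketch of the cast-and-restore pattern used above (the tensor name is illustrative): since torch.linalg.qr lacks a bfloat16 kernel as of PyTorch 2.5.1, the orthogonalization runs in float32 and the result is cast back to the conditioner's storage dtype:

import torch

power_iter = torch.randn(16, 16, dtype=torch.bfloat16)

# Orthogonalize in float32, where QR is supported and numerically stable ...
q, _ = torch.linalg.qr(power_iter.to(torch.float32))
# ... then return to the dtype the conditioner is stored in.
q = q.to(power_iter.dtype)

print(q.dtype)  # torch.bfloat16
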
@@ -185,7 +193,7 @@ def init_pre_conditioner(
             if not precondition_1d or grad.shape[0] > max_precondition_dim:
                 state['GG'].append([])
             else:
-                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
+                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
         else:
             if merge_dims:
                 grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))
@@ -194,7 +202,7 @@ def init_pre_conditioner(
                 if sh > max_precondition_dim:
                     state['GG'].append([])
                 else:
-                    state['GG'].append(torch.zeros(sh, sh, device=grad.device))
+                    state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))

         state['Q'] = None
         state['precondition_frequency'] = precondition_frequency
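
Putting the state-initialization changes together, a hedged sketch (the init_gg helper below is hypothetical, not the optimizer's API) of how the GG accumulators now inherit the gradient's dtype instead of defaulting to float32:

import torch

def init_gg(grad: torch.Tensor, max_precondition_dim: int = 10_000) -> list:
    # Hypothetical helper mirroring init_pre_conditioner: one accumulator per
    # gradient dimension, allocated on the gradient's device and in its dtype.
    gg = []
    for sh in grad.shape:
        if sh > max_precondition_dim:
            gg.append([])
        else:
            gg.append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
    return gg

grad = torch.randn(32, 8, dtype=torch.bfloat16)
print([g.dtype for g in init_gg(grad)])  # [torch.bfloat16, torch.bfloat16]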
