
Commit c7496b0

Support alternative precision training (#333)
Cast the projection matrix to the gradient's dtype during projection to fix a dtype-mismatch error. The internal float32 representation of the matrix is kept for precision.
1 parent 81acb1e commit c7496b0

1 file changed (+1, −1)


pytorch_optimizer/optimizer/soap.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def project(
         for mat in state['Q']:
             if len(mat) > 0:
-                grad = torch.tensordot(grad, mat, dims=[[0], [0 if project_type == 'forward' else 1]])
+                grad = torch.tensordot(grad, mat.to(grad.dtype), dims=[[0], [0 if project_type == 'forward' else 1]])
             else:
                 grad = grad.permute([*list(range(1, len(grad.shape))), 0])
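For context, a minimal sketch of the dtype mismatch this cast avoids. The standalone grad/mat tensors, their shapes, and bfloat16 as the training dtype below are illustrative assumptions, not code from soap.py: torch.tensordot contracts via matmul, which requires both operands to share a dtype, so a low-precision gradient contracted against a float32 preconditioner raises a RuntimeError; casting the matrix to grad.dtype only for the contraction leaves the stored float32 copy untouched.

import torch

# Illustrative tensors (shapes and dtypes are hypothetical, not taken from soap.py).
grad = torch.randn(4, 8, dtype=torch.bfloat16)  # low-precision training gradient
mat = torch.randn(4, 4, dtype=torch.float32)    # projection matrix kept in float32 internally

# Before the fix: mixed dtypes make the underlying matmul raise a RuntimeError.
# torch.tensordot(grad, mat, dims=[[0], [0]])

# After the fix: cast only for this contraction; `mat` itself stays float32.
out = torch.tensordot(grad, mat.to(grad.dtype), dims=[[0], [0]])
print(out.dtype, mat.dtype)  # torch.bfloat16 torch.float32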
