@@ -102,7 +102,7 @@ def project(
102102
103103 for mat in state ['Q' ]:
104104 if len (mat ) > 0 :
105- grad = torch .tensordot (grad , mat . to ( grad . dtype ) , dims = [[0 ], [0 if project_type == 'forward' else 1 ]])
105+ grad = torch .tensordot (grad , mat , dims = [[0 ], [0 if project_type == 'forward' else 1 ]])
106106 else :
107107 grad = grad .permute ([* list (range (1 , len (grad .shape ))), 0 ])
108108
@@ -123,9 +123,11 @@ def get_orthogonal_matrix(mat: torch.Tensor) -> List[torch.Tensor]:
123123 continue
124124
125125 try :
126- _ , q = torch .linalg .eigh (m + 1e-30 * torch .eye (m .shape [0 ], device = m .device ))
126+ _ , q = torch .linalg .eigh (m + 1e-30 * torch .eye (m .shape [0 ], device = m .device , dtype = m . dtype ))
127127 except Exception : # pragma: no cover
128- _ , q = torch .linalg .eigh (m .to (torch .float64 ) + 1e-30 * torch .eye (m .shape [0 ], device = m .device ))
128+ _ , q = torch .linalg .eigh (
129+ m .to (torch .float64 ) + 1e-30 * torch .eye (m .shape [0 ], device = m .device , dtype = torch .float64 )
130+ )
129131 q = q .to (m .dtype )
130132
131133 q = torch .flip (q , dims = [1 ])
@@ -156,7 +158,13 @@ def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, mer
156158
157159 power_iter = m @ o [:, sort_idx ]
158160
159- q , _ = torch .linalg .qr (power_iter )
161+ # Compute QR decomposition
162+ # We cast to float32 because:
163+ # - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
164+ # - the correctness / numerical stability of the Q orthogonalization is important for the stability
165+ # of the optimizer
166+ q , _ = torch .linalg .qr (power_iter .to (torch .float32 ))
167+ q = q .to (power_iter .dtype )
160168
161169 matrices .append (q )
162170
@@ -185,7 +193,7 @@ def init_pre_conditioner(
185193 if not precondition_1d or grad .shape [0 ] > max_precondition_dim :
186194 state ['GG' ].append ([])
187195 else :
188- state ['GG' ].append (torch .zeros (grad .shape [0 ], grad .shape [0 ], device = grad .device ))
196+ state ['GG' ].append (torch .zeros (grad .shape [0 ], grad .shape [0 ], device = grad .device , dtype = grad . dtype ))
189197 else :
190198 if merge_dims :
191199 grad = grad .reshape (merge_small_dims (grad .size (), max_precondition_dim ))
@@ -194,7 +202,7 @@ def init_pre_conditioner(
194202 if sh > max_precondition_dim :
195203 state ['GG' ].append ([])
196204 else :
197- state ['GG' ].append (torch .zeros (sh , sh , device = grad .device ))
205+ state ['GG' ].append (torch .zeros (sh , sh , device = grad .device , dtype = grad . dtype ))
198206
199207 state ['Q' ] = None
200208 state ['precondition_frequency' ] = precondition_frequency
0 commit comments