|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
| 15 | +from functools import partial |
15 | 16 | from itertools import chain |
16 | 17 | from typing import Callable, Iterable, List, Optional, Tuple, Union |
17 | 18 |
|
@@ -197,35 +198,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: |
197 | 198 | # Exponential moving average of squared gradient values |
198 | 199 | state["exp_avg_sq"] = torch.zeros_like(grad) |
199 | 200 |
|
| 201 | + if "Q" not in state: |
| 202 | + state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape] |
| 203 | + |
| 204 | + # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here |
| 205 | + # because it needs access to state and group |
| 206 | + kronecker_factor_update_fn = partial( |
| 207 | + update_kronecker_factors, precondition_1d=group["precondition_1d"] |
| 208 | + ) |
| 209 | + if group["use_kl_shampoo"]: |
| 210 | + kronecker_factor_update_fn = partial( |
| 211 | + update_kronecker_factors_kl_shampoo, |
| 212 | + eigenbasis_list=state["Q"], |
| 213 | + eps=group["eps"], |
| 214 | + ) |
| 215 | + |
200 | 216 | # Initialize kronecker factor matrices |
201 | 217 | if "GG" not in state: |
202 | 218 | state["GG"] = init_kronecker_factors( |
203 | 219 | grad, |
204 | 220 | precondition_1d=group["precondition_1d"], |
205 | 221 | ) |
206 | 222 |
|
207 | | - assert "Q" not in state, "Q should not be initialized yet" |
208 | | - state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape] |
209 | | - |
210 | 223 | # Update preconditioner matrices with gradient statistics, |
211 | 224 | # do not use shampoo_beta for EMA at first step |
212 | 225 | with utils.fp32_matmul_precision(group["fp32_matmul_prec"]): |
213 | | - kronecker_factor_update_kwargs = dict( |
214 | | - kronecker_factor_list=state["GG"], |
215 | | - grad=grad, |
216 | | - shampoo_beta=0.0, |
| 226 | + kronecker_factor_update_fn( |
| 227 | +                kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=0.0
217 | 228 | ) |
218 | | - if group["use_kl_shampoo"]: |
219 | | - update_kronecker_factors_kl_shampoo( |
220 | | - **kronecker_factor_update_kwargs, |
221 | | - eigenbasis_list=state["Q"], |
222 | | - eps=group["eps"], |
223 | | - ) |
224 | | - else: |
225 | | - update_kronecker_factors( |
226 | | - **kronecker_factor_update_kwargs, |
227 | | - precondition_1d=group["precondition_1d"], |
228 | | - ) |
229 | 229 |
|
230 | 230 | # Increment step counter |
231 | 231 | state["step"] += 1 |
@@ -299,21 +299,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: |
299 | 299 |
|
300 | 300 | torch.cuda.nvtx.range_push("update_kronecker_factors") |
301 | 301 | with utils.fp32_matmul_precision(group["fp32_matmul_prec"]): |
302 | | - if group["use_kl_shampoo"]: |
303 | | - update_kronecker_factors_kl_shampoo( |
304 | | - kronecker_factor_list=state["GG"], |
305 | | - grad=grad, |
306 | | - shampoo_beta=shampoo_beta, |
307 | | - eigenbasis_list=state["Q"], |
308 | | - eps=group["eps"], |
309 | | - ) |
310 | | - else: |
311 | | - update_kronecker_factors( |
312 | | - kronecker_factor_list=state["GG"], |
313 | | - grad=grad, |
314 | | - shampoo_beta=shampoo_beta, |
315 | | - precondition_1d=group["precondition_1d"], |
316 | | - ) |
| 302 | + kronecker_factor_update_fn( |
| 303 | + kronecker_factor_list=state["GG"], |
| 304 | + grad=grad, |
| 305 | +                shampoo_beta=shampoo_beta,
| 306 | + ) |
317 | 307 | torch.cuda.nvtx.range_pop() |
318 | 308 |
|
319 | 309 | # If current step is the last step to skip preconditioning, initialize eigenbases and |
|
0 commit comments