|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +import math |
| 16 | +from typing import Callable, List, Tuple, override |
| 17 | + |
| 18 | +import torch |
| 19 | +from torch.optim.optimizer import ParamsT |
| 20 | + |
| 21 | +from emerging_optimizers.psgd.procrustes_step import procrustes_step |
| 22 | +from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner, partial_contraction |
| 23 | +from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd, uniformize_q_in_place |
| 24 | +from emerging_optimizers.soap.soap import _clip_update_rms_in_place |
| 25 | + |
| 26 | + |
| 27 | +__all__ = [ |
| 28 | + "PSGDPro", |
| 29 | +] |
| 30 | + |
| 31 | + |
| 32 | +class PSGDPro(torch.optim.Optimizer): |
| 33 | + """Implements a variant of the PSGD optimization algorithm (PSGD-Kron-Whiten with Procrustes step for preconditioner update). |
| 34 | +
|
| 35 | + Preconditioned Stochastic Gradient Descent (PSGD) (https://arxiv.org/abs/1512.04202) is a preconditioned optimization algorithm |
| 36 | + that fits amplitudes of perturbations of preconditioned stochastic gradient to match that of the perturbations of parameters. |
| 37 | + PSGD with Kronecker-factored Preconditioner (PSGD-Kron-Whiten) is a variant of PSGD that reduces memory and computational complexity. |
| 38 | + Procrustes step is an algorithm to update the preconditioner which respects a particular geometry: Q^0.5 * E * Q^1.5, see Stochastic Hessian |
| 39 | + Fittings with Lie Groups (https://arxiv.org/abs/2402.11858) for more details. |
| 40 | +
|
| 41 | + Args: |
| 42 | + params: Iterable of parameters to optimize or dicts defining parameter groups |
| 43 | + lr: The learning rate to use |
| 44 | + weight_decay: Weight decay coefficient |
| 45 | + use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization: |
| 46 | + https://arxiv.org/abs/1711.05101. |
| 47 | + momentum: Momentum coefficient for exponential moving average of gradient. |
| 48 | + beta_lip: EMA beta for the Lipschitz constants. |
| 49 | + precond_lr: Inner learning rate for the preconditioner. |
| 50 | + precond_init_scale: scale of initial preconditioner values. |
| 51 | + min_precond_lr: Minimum learning rate for preconditioner learning rate schedule. |
| 52 | + warmup_steps: Warmup steps for preconditioner learning rate schedule. |
| 53 | + damping_noise_scale: scale of dampening noise added to gradients. |
| 54 | + max_update_rms: Clip the update RMS to this value (0 means no clipping). |
| 55 | + """ |
| 56 | + |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + params: ParamsT, |
| 60 | + lr: float = 3e-3, |
| 61 | + weight_decay: float = 0.01, |
| 62 | + use_decoupled_weight_decay: bool = True, |
| 63 | + momentum: float = 0.9, |
| 64 | + beta_lip: float = 0.9, |
| 65 | + precond_lr: float = 0.1, |
| 66 | + precond_init_scale: float = 1.0, |
| 67 | + damping_noise_scale: float = 0.1, |
| 68 | + min_precond_lr: float = 0.01, |
| 69 | + warmup_steps: int = 10000, |
| 70 | + max_update_rms: float = 0.0, |
| 71 | + ) -> None: |
| 72 | + defaults = { |
| 73 | + "lr": lr, |
| 74 | + "beta_lip": beta_lip, |
| 75 | + "weight_decay": weight_decay, |
| 76 | + "use_decoupled_weight_decay": use_decoupled_weight_decay, |
| 77 | + "momentum": momentum, |
| 78 | + "precond_lr": precond_lr, |
| 79 | + "precond_init_scale": precond_init_scale, |
| 80 | + "max_update_rms": max_update_rms, |
| 81 | + "min_precond_lr": min_precond_lr, |
| 82 | + "warmup_steps": warmup_steps, |
| 83 | + "damping_noise_scale": damping_noise_scale, |
| 84 | + } |
| 85 | + super().__init__(params, defaults) |
| 86 | + |
| 87 | + @torch.no_grad() # type: ignore[misc] |
| 88 | + @override |
| 89 | + def step(self, closure: Callable[[], float] | None = None) -> float | None: |
| 90 | + """Performs a single optimization step. |
| 91 | +
|
| 92 | + Args: |
| 93 | + closure: A closure that reevaluates the model and returns the loss. |
| 94 | + """ |
| 95 | + if closure is None: |
| 96 | + loss = None |
| 97 | + else: |
| 98 | + loss = closure() |
| 99 | + |
| 100 | + for group in self.param_groups: |
| 101 | + for p in group["params"]: |
| 102 | + if p.grad is None: |
| 103 | + continue |
| 104 | + grad = p.grad |
| 105 | + state = self.state[p] |
| 106 | + |
| 107 | + # Optimizer state initialization |
| 108 | + if "step" not in state: |
| 109 | + state["step"] = 0 |
| 110 | + # Momentum buffer |
| 111 | + if "exp_avg" not in state: |
| 112 | + state["exp_avg"] = torch.zeros_like(grad) |
| 113 | + # PSGD kronecker factor matrices and Lipschitz constants initialization |
| 114 | + if "Q" not in state or "L" not in state: |
| 115 | + state["Q"], state["L"] = _init_psgd_kron_states( |
| 116 | + grad, |
| 117 | + precond_init_scale=group["precond_init_scale"], |
| 118 | + ) |
| 119 | + |
| 120 | + # weight decay |
| 121 | + if group["weight_decay"] > 0.0: |
| 122 | + if group["use_decoupled_weight_decay"]: |
| 123 | + # Apply decoupled weight decay |
| 124 | + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) |
| 125 | + else: |
| 126 | + # add l2 regularization before preconditioning (i.e. adding a squared loss term) |
| 127 | + grad += group["weight_decay"] * p |
| 128 | + |
| 129 | + # update momentum buffer with EMA of gradient |
| 130 | + exp_avg = state["exp_avg"] |
| 131 | + exp_avg.lerp_(grad, 1 - group["momentum"]) |
| 132 | + |
| 133 | + # Get hyperparameters for preconditioner update |
| 134 | + damping_noise_scale = group["damping_noise_scale"] |
| 135 | + precond_lr = _get_precond_lr( |
| 136 | + group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"] |
| 137 | + ) |
| 138 | + |
| 139 | + beta_lip = group["beta_lip"] |
| 140 | + # Preconditioner update |
| 141 | + state["Q"], state["L"] = _update_precond_procrustes( |
| 142 | + state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip |
| 143 | + ) |
| 144 | + uniformize_q_in_place(state["Q"]) |
| 145 | + |
| 146 | + # Get weight update by preconditioning the momentum |
| 147 | + update = apply_preconditioner(state["Q"], exp_avg) |
| 148 | + _clip_update_rms_in_place(update, group["max_update_rms"]) |
| 149 | + |
| 150 | + # Apply weight update |
| 151 | + p.add_(update, alpha=-group["lr"]) |
| 152 | + |
| 153 | + return loss |
| 154 | + |
| 155 | + |
| 156 | +def _init_psgd_kron_states( |
| 157 | + grad: torch.Tensor, |
| 158 | + precond_init_scale: float = 1.0, |
| 159 | +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: |
| 160 | + """Initialize the Kronecker factor matrices and Lipschitz constants. |
| 161 | +
|
| 162 | + Args: |
| 163 | + grad: Gradient tensor. |
| 164 | + precond_init_scale: Scale of preconditioner initialization. |
| 165 | +
|
| 166 | + Returns: |
| 167 | + q_list: List of Kronecker factors. |
| 168 | + lip_const_list: List of Lipschitz constants for the Kronecker factors. |
| 169 | + """ |
| 170 | + q_list: List[torch.Tensor] = [] |
| 171 | + lip_const_list: List[torch.Tensor] = [] |
| 172 | + |
| 173 | + # Create identity matrices scaled by precond_init_scale for each dimension |
| 174 | + for size in grad.shape: |
| 175 | + q_list.append(torch.eye(size, device=grad.device) * precond_init_scale) |
| 176 | + lip_const_list.append(torch.ones((), device=grad.device)) |
| 177 | + |
| 178 | + return q_list, lip_const_list |
| 179 | + |
| 180 | + |
| 181 | +def _update_precond_procrustes( |
| 182 | + q_list: List[torch.Tensor], |
| 183 | + lip_const_list: List[torch.Tensor], |
| 184 | + exp_avg: torch.Tensor, |
| 185 | + damping_noise_scale: float = 1e-9, |
| 186 | + precond_lr: float = 0.1, |
| 187 | + beta_lip: float = 0.9, |
| 188 | +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: |
| 189 | + r"""Update the Kron preconditioner Q using procrustes step and uniformization. |
| 190 | +
|
| 191 | + Args: |
| 192 | + q_list: List of Kronecker factors. |
| 193 | + lip_const_list: List of Lipschitz constants for the Kronecker factors. |
| 194 | + exp_avg: Exponential moving average of gradient. |
| 195 | + damping_noise_scale: Scale of noise added to gradient. |
| 196 | + precond_lr: Learning rate. |
| 197 | + beta_lip: EMA beta for the Lipschitz constant. |
| 198 | +
|
| 199 | + Returns: |
| 200 | + q_list: List of Kronecker factors. |
| 201 | + lip_const_list: List of Lipschitz constants for the Kronecker factors. |
| 202 | + """ |
| 203 | + dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg) |
| 204 | + pg = apply_preconditioner(q_list, dampened_momentum) |
| 205 | + total_numel = pg.numel() |
| 206 | + updated_q_list: List[torch.Tensor] = [] |
| 207 | + updated_lip_const_list: List[torch.Tensor] = [] |
| 208 | + for dim, q in enumerate(q_list): |
| 209 | + # compute gradient covariance |
| 210 | + precond_grad_cov = partial_contraction(pg, pg, dim) |
| 211 | + if q.dim() < 2: |
| 212 | + # diagonal or scalar-structured preconditioner |
| 213 | + q, updated_lip_const = _update_1d_preconditioner( |
| 214 | + q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip |
| 215 | + ) |
| 216 | + else: |
| 217 | + # matrix-structured preconditioner |
| 218 | + q, updated_lip_const = _update_matrix_preconditioner( |
| 219 | + q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip |
| 220 | + ) |
| 221 | + updated_q_list.append(q) |
| 222 | + updated_lip_const_list.append(updated_lip_const) |
| 223 | + |
| 224 | + return updated_q_list, updated_lip_const_list |
| 225 | + |
| 226 | + |
| 227 | +def _update_matrix_preconditioner( |
| 228 | + q: torch.Tensor, |
| 229 | + lip_const: torch.Tensor, |
| 230 | + precond_grad_cov: torch.Tensor, |
| 231 | + total_numel: int, |
| 232 | + precond_lr: float, |
| 233 | + beta_lip: float, |
| 234 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 235 | + r"""Update matrix-structured preconditioner with adaptive Lipschitz constant. |
| 236 | +
|
| 237 | + Args: |
| 238 | + q: Kronecker factor matrix for this dimension to update. |
| 239 | + lip_const: Lipschitz constant for this dimension. |
| 240 | + precond_grad_cov: Gradient covariance. |
| 241 | + total_numel: Total number of elements in the gradient. |
| 242 | + precond_lr: Learning rate. |
| 243 | + beta_lip: EMA beta for the Lipschitz constant. |
| 244 | +
|
| 245 | + Returns: |
| 246 | + q: Updated Kronecker factor matrix for this dimension. |
| 247 | + lip_const: Updated Lipschitz constant for this dimension. |
| 248 | + """ |
| 249 | + normalization = total_numel / q.shape[0] |
| 250 | + ell = norm_lower_bound_spd(precond_grad_cov) + normalization |
| 251 | + lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell) |
| 252 | + q = q - precond_lr / lip_const * (precond_grad_cov @ q - normalization * q) |
| 253 | + q = procrustes_step(q) |
| 254 | + return q, lip_const |
| 255 | + |
| 256 | + |
| 257 | +def _update_1d_preconditioner( |
| 258 | + q: torch.Tensor, |
| 259 | + lip_const: torch.Tensor, |
| 260 | + precond_grad_cov: torch.Tensor, |
| 261 | + total_numel: int, |
| 262 | + precond_lr: float, |
| 263 | + beta_lip: float, |
| 264 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 265 | + r"""Update 1D preconditioner with adaptive Lipschitz constant. |
| 266 | +
|
| 267 | + Args: |
| 268 | + q: Kronecker factor 1D tensor for this dimension to update. |
| 269 | + lip_const: Lipschitz constant for this dimension. |
| 270 | + precond_grad_cov: Gradient covariance. |
| 271 | + total_numel: Total number of elements in the gradient. |
| 272 | + precond_lr: Learning rate. |
| 273 | + beta_lip: EMA beta for the Lipschitz constant. |
| 274 | +
|
| 275 | + Returns: |
| 276 | + q: Updated Kronecker factor 1D tensor for this dimension. |
| 277 | + lip_const: Updated Lipschitz constant for this dimension. |
| 278 | + """ |
| 279 | + normalization = total_numel / q.numel() |
| 280 | + ell = torch.max(precond_grad_cov) + normalization |
| 281 | + lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell) |
| 282 | + q = q * (1 - precond_lr / lip_const * (precond_grad_cov - normalization)) |
| 283 | + return q, lip_const |
| 284 | + |
| 285 | + |
| 286 | +def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float = 0.01, warmup_steps: int = 10000) -> float: |
| 287 | + r"""Helper function to get preconditioner learning rate for this optimization step based on a square root schedule. |
| 288 | +
|
| 289 | + Decaying from a higher lr down to min_precond_lr improves accuracy. |
| 290 | +
|
| 291 | + Args: |
| 292 | + precond_lr: Learning rate. |
| 293 | + step: Current step. |
| 294 | + min_precond_lr: Minimum learning rate. |
| 295 | + warmup_steps: Warmup steps. |
| 296 | +
|
| 297 | + Returns: |
| 298 | + The preconditioner learning rate. |
| 299 | + """ |
| 300 | + |
| 301 | + scheduled_lr = precond_lr / math.sqrt(1.0 + step / warmup_steps) |
| 302 | + return max(scheduled_lr, min_precond_lr) |
0 commit comments