
Commit f49d04e

PSGD-Kron-Pro(crustes) optimizer implementation (#60)
* added psgd with convergence test, removed torch.compile from partial contraction
* added psgd convergence test to ci
* updated docstring of psgd
* added PSGD's init and more documentation
* addressed MR comments for style
* changed argument name in convergence test
* addressed mr comments
* removed double hyperparam initialization
* changed convergence test to per element MSE instead, previous was global
* changed subspace iteration to dim of 32 as per xilin's suggestion for fp32
* changed dampening to match Xilin's suggestion
* replaced torch.finfo with explicit value

Signed-off-by: mikail <[email protected]>
1 parent b724bb7 commit f49d04e

File tree: 6 files changed (+532, -5 lines)
emerging_optimizers/psgd/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from emerging_optimizers.psgd.psgd import *

emerging_optimizers/psgd/psgd.py

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, List, Tuple, override

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers.psgd.procrustes_step import procrustes_step
from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner, partial_contraction
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd, uniformize_q_in_place
from emerging_optimizers.soap.soap import _clip_update_rms_in_place


__all__ = [
    "PSGDPro",
]


class PSGDPro(torch.optim.Optimizer):
    """Implements a variant of the PSGD optimization algorithm (PSGD-Kron-Whiten with a Procrustes step for the preconditioner update).

    Preconditioned Stochastic Gradient Descent (PSGD) (https://arxiv.org/abs/1512.04202) is a preconditioned optimization algorithm
    that fits the amplitudes of perturbations of the preconditioned stochastic gradient to match those of the perturbations of the parameters.
    PSGD with a Kronecker-factored preconditioner (PSGD-Kron-Whiten) is a variant of PSGD that reduces memory and computational complexity.
    The Procrustes step is an algorithm for updating the preconditioner that respects a particular geometry, Q^0.5 * E * Q^1.5; see Stochastic Hessian
    Fittings with Lie Groups (https://arxiv.org/abs/2402.11858) for more details.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: The learning rate to use.
        weight_decay: Weight decay coefficient.
        use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
            https://arxiv.org/abs/1711.05101.
        momentum: Momentum coefficient for the exponential moving average of the gradient.
        beta_lip: EMA beta for the Lipschitz constants.
        precond_lr: Inner learning rate for the preconditioner.
        precond_init_scale: Scale of the initial preconditioner values.
        min_precond_lr: Minimum learning rate for the preconditioner learning rate schedule.
        warmup_steps: Warmup steps for the preconditioner learning rate schedule.
        damping_noise_scale: Scale of the dampening noise added to gradients.
        max_update_rms: Clip the update RMS to this value (0 means no clipping).
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-3,
        weight_decay: float = 0.01,
        use_decoupled_weight_decay: bool = True,
        momentum: float = 0.9,
        beta_lip: float = 0.9,
        precond_lr: float = 0.1,
        precond_init_scale: float = 1.0,
        damping_noise_scale: float = 0.1,
        min_precond_lr: float = 0.01,
        warmup_steps: int = 10000,
        max_update_rms: float = 0.0,
    ) -> None:
        defaults = {
            "lr": lr,
            "beta_lip": beta_lip,
            "weight_decay": weight_decay,
            "use_decoupled_weight_decay": use_decoupled_weight_decay,
            "momentum": momentum,
            "precond_lr": precond_lr,
            "precond_init_scale": precond_init_scale,
            "max_update_rms": max_update_rms,
            "min_precond_lr": min_precond_lr,
            "warmup_steps": warmup_steps,
            "damping_noise_scale": damping_noise_scale,
        }
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        if closure is None:
            loss = None
        else:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Optimizer state initialization
                if "step" not in state:
                    state["step"] = 0
                # Momentum buffer
                if "exp_avg" not in state:
                    state["exp_avg"] = torch.zeros_like(grad)
                # PSGD Kronecker factor matrices and Lipschitz constants initialization
                if "Q" not in state or "L" not in state:
                    state["Q"], state["L"] = _init_psgd_kron_states(
                        grad,
                        precond_init_scale=group["precond_init_scale"],
                    )

                # weight decay
                if group["weight_decay"] > 0.0:
                    if group["use_decoupled_weight_decay"]:
                        # Apply decoupled weight decay
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                    else:
                        # add L2 regularization before preconditioning (i.e. adding a squared loss term)
                        grad += group["weight_decay"] * p

                # update momentum buffer with EMA of the gradient
                exp_avg = state["exp_avg"]
                exp_avg.lerp_(grad, 1 - group["momentum"])

                # Get hyperparameters for the preconditioner update
                damping_noise_scale = group["damping_noise_scale"]
                precond_lr = _get_precond_lr(
                    group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
                )

                beta_lip = group["beta_lip"]
                # Preconditioner update
                state["Q"], state["L"] = _update_precond_procrustes(
                    state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
                )
                uniformize_q_in_place(state["Q"])

                # Get the weight update by preconditioning the momentum
                update = apply_preconditioner(state["Q"], exp_avg)
                _clip_update_rms_in_place(update, group["max_update_rms"])

                # Apply weight update
                p.add_(update, alpha=-group["lr"])
                # Advance the step count used by the preconditioner learning rate schedule
                state["step"] += 1

        return loss


def _init_psgd_kron_states(
    grad: torch.Tensor,
    precond_init_scale: float = 1.0,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Initialize the Kronecker factor matrices and Lipschitz constants.

    Args:
        grad: Gradient tensor.
        precond_init_scale: Scale of preconditioner initialization.

    Returns:
        q_list: List of Kronecker factors.
        lip_const_list: List of Lipschitz constants for the Kronecker factors.
    """
    q_list: List[torch.Tensor] = []
    lip_const_list: List[torch.Tensor] = []

    # Create identity matrices scaled by precond_init_scale for each dimension
    for size in grad.shape:
        q_list.append(torch.eye(size, device=grad.device) * precond_init_scale)
        lip_const_list.append(torch.ones((), device=grad.device))

    return q_list, lip_const_list


def _update_precond_procrustes(
    q_list: List[torch.Tensor],
    lip_const_list: List[torch.Tensor],
    exp_avg: torch.Tensor,
    damping_noise_scale: float = 1e-9,
    precond_lr: float = 0.1,
    beta_lip: float = 0.9,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    r"""Update the Kron preconditioner Q using a Procrustes step and uniformization.

    Args:
        q_list: List of Kronecker factors.
        lip_const_list: List of Lipschitz constants for the Kronecker factors.
        exp_avg: Exponential moving average of the gradient.
        damping_noise_scale: Scale of noise added to the gradient.
        precond_lr: Learning rate.
        beta_lip: EMA beta for the Lipschitz constant.

    Returns:
        q_list: List of updated Kronecker factors.
        lip_const_list: List of updated Lipschitz constants for the Kronecker factors.
    """
    dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg)
    pg = apply_preconditioner(q_list, dampened_momentum)
    total_numel = pg.numel()
    updated_q_list: List[torch.Tensor] = []
    updated_lip_const_list: List[torch.Tensor] = []
    for dim, q in enumerate(q_list):
        # compute gradient covariance
        precond_grad_cov = partial_contraction(pg, pg, dim)
        if q.dim() < 2:
            # diagonal or scalar-structured preconditioner
            q, updated_lip_const = _update_1d_preconditioner(
                q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip
            )
        else:
            # matrix-structured preconditioner
            q, updated_lip_const = _update_matrix_preconditioner(
                q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip
            )
        updated_q_list.append(q)
        updated_lip_const_list.append(updated_lip_const)

    return updated_q_list, updated_lip_const_list


def _update_matrix_preconditioner(
    q: torch.Tensor,
    lip_const: torch.Tensor,
    precond_grad_cov: torch.Tensor,
    total_numel: int,
    precond_lr: float,
    beta_lip: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Update a matrix-structured preconditioner with an adaptive Lipschitz constant.

    Args:
        q: Kronecker factor matrix for this dimension to update.
        lip_const: Lipschitz constant for this dimension.
        precond_grad_cov: Gradient covariance.
        total_numel: Total number of elements in the gradient.
        precond_lr: Learning rate.
        beta_lip: EMA beta for the Lipschitz constant.

    Returns:
        q: Updated Kronecker factor matrix for this dimension.
        lip_const: Updated Lipschitz constant for this dimension.
    """
    normalization = total_numel / q.shape[0]
    ell = norm_lower_bound_spd(precond_grad_cov) + normalization
    lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell)
    q = q - precond_lr / lip_const * (precond_grad_cov @ q - normalization * q)
    q = procrustes_step(q)
    return q, lip_const


def _update_1d_preconditioner(
    q: torch.Tensor,
    lip_const: torch.Tensor,
    precond_grad_cov: torch.Tensor,
    total_numel: int,
    precond_lr: float,
    beta_lip: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Update a 1D preconditioner with an adaptive Lipschitz constant.

    Args:
        q: Kronecker factor 1D tensor for this dimension to update.
        lip_const: Lipschitz constant for this dimension.
        precond_grad_cov: Gradient covariance.
        total_numel: Total number of elements in the gradient.
        precond_lr: Learning rate.
        beta_lip: EMA beta for the Lipschitz constant.

    Returns:
        q: Updated Kronecker factor 1D tensor for this dimension.
        lip_const: Updated Lipschitz constant for this dimension.
    """
    normalization = total_numel / q.numel()
    ell = torch.max(precond_grad_cov) + normalization
    lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell)
    q = q * (1 - precond_lr / lip_const * (precond_grad_cov - normalization))
    return q, lip_const


def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float = 0.01, warmup_steps: int = 10000) -> float:
    r"""Helper function to get the preconditioner learning rate for this optimization step based on a square-root schedule.

    Decaying from a higher lr down to min_precond_lr improves accuracy.

    Args:
        precond_lr: Learning rate.
        step: Current step.
        min_precond_lr: Minimum learning rate.
        warmup_steps: Warmup steps.

    Returns:
        The preconditioner learning rate.
    """
    scheduled_lr = precond_lr / math.sqrt(1.0 + step / warmup_steps)
    return max(scheduled_lr, min_precond_lr)
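A quick numeric illustration of the square-root schedule implemented by _get_precond_lr above (not part of the commit), using the optimizer's default hyperparameters:

import math

def scheduled_precond_lr(step, precond_lr=0.1, min_precond_lr=0.01, warmup_steps=10000):
    # mirrors _get_precond_lr: square-root decay with a floor at min_precond_lr
    return max(precond_lr / math.sqrt(1.0 + step / warmup_steps), min_precond_lr)

print(scheduled_precond_lr(0))          # 0.1 at the first step
print(scheduled_precond_lr(10_000))     # ~0.0707, i.e. 0.1 / sqrt(2) after warmup_steps steps
print(scheduled_precond_lr(1_000_000))  # 0.01, clipped once the decay falls below min_precond_lr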

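For orientation, a minimal usage sketch of the new optimizer (not part of the commit; the toy model and training loop are illustrative, and PSGDPro is assumed importable from emerging_optimizers.psgd via the __init__.py above):

import torch
from emerging_optimizers.psgd import PSGDPro  # re-exported by the package __init__ shown above

model = torch.nn.Linear(16, 4)
optimizer = PSGDPro(model.parameters(), lr=3e-3, weight_decay=0.01, momentum=0.9)

for _ in range(10):
    x, target = torch.randn(32, 16), torch.randn(32, 4)
    loss = torch.nn.functional.mse_loss(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # preconditions the momentum with the Kronecker-factored Q and applies the update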
emerging_optimizers/psgd/psgd_kron_contractions.py

Lines changed: 2 additions & 4 deletions
@@ -24,7 +24,6 @@
 ]
 
 
-@torch.compile  # type: ignore[misc]
 def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
     """Compute the partial contraction of G1 and G2 along axis `axis`.
     This is the contraction of the two tensors, but with all axes except `axis` contracted.
@@ -38,10 +37,9 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.
         Tensor of shape (d_{axis}, d_{axis})
     """
     # dims_to_contract = all dims except `axis`
-    dims = list(range(G1.dim()))
-    dims.pop(axis)
+    dims_to_contract = [i for i in range(G1.dim()) if i != axis]
     # contraction is symmetric and has shape (d_{axis}, d_{axis})
-    return torch.tensordot(G1, G2, dims=(dims, dims))
+    return torch.tensordot(G1, G2, dims=(dims_to_contract, dims_to_contract))
 
 
 @torch.compile  # type: ignore[misc]
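For intuition (illustrative, not part of the commit), the refactored partial_contraction is a Gram-style contraction over every axis except the one kept; a small self-contained check against an explicit einsum:

import torch

def partial_contraction(G1, G2, axis):
    # same logic as the refactored function above, without torch.compile
    dims_to_contract = [i for i in range(G1.dim()) if i != axis]
    return torch.tensordot(G1, G2, dims=(dims_to_contract, dims_to_contract))

G = torch.randn(3, 4, 5)
out = partial_contraction(G, G, axis=1)      # shape (4, 4), symmetric
ref = torch.einsum("iaj,ibj->ab", G, G)      # contract axes 0 and 2 explicitly
assert out.shape == (4, 4) and torch.allclose(out, ref, atol=1e-5)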

emerging_optimizers/psgd/psgd_utils.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
 
 
 @torch.compile  # type: ignore[misc]
-def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
+def norm_lower_bound_spd(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
     r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix.
 
 
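For context (illustrative, not part of the commit): the docstring above describes a cheap lower bound on the spectral norm of an SPD matrix, so one sanity check is to compare it against the exact spectral norm. The sketch below assumes norm_lower_bound_spd returns a scalar tensor, as its use in _update_matrix_preconditioner suggests.

import torch
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd  # import path as used in psgd.py

X = torch.randn(64, 64)
A = X @ X.T + 1e-3 * torch.eye(64)          # random symmetric positive definite matrix
lower_bound = norm_lower_bound_spd(A)       # cheap bound; k=32 is the new default subspace dimension
exact = torch.linalg.matrix_norm(A, ord=2)  # exact spectral norm (largest singular value)
assert float(lower_bound) <= float(exact) * (1.0 + 1e-4)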

tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
@@ -29,5 +29,6 @@ coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py
 coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda || error=1
 coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda || error=1
 coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda || error=1
+coverage run -p --source=emerging_optimizers tests/test_psgd_convergence.py --device=cuda || error=1
 
 exit "${error}"
