-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathorthogonalized_optimizer.py
More file actions
210 lines (170 loc) · 8.83 KB
/
orthogonalized_optimizer.py
File metadata and controls
210 lines (170 loc) · 8.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, overload, override
import torch
import torch.optim as optim
from absl import logging
from torch.optim.optimizer import ParamsT
from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import utils
from emerging_optimizers.utils import FP32MatmulPrecT
# Shared documentation for the constructor arguments of orthogonalized optimizers.
# The text is substituted for the ``{_args_doc}`` placeholder in the docstring of
# OrthogonalizedOptimizer by the ``str.format`` call at the bottom of this file.
# NOTE(review): the first entry carries no leading indentation because the placeholder itself is
# already indented inside the class docstring; the indentation of the following lines presumably
# has to line up with that docstring's ``Args:`` section - confirm against the rendered docs.
_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate used by the internal SGD.
momentum_beta: The momentum used by the internal SGD.
weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin`
for more details.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""
class OrthogonalizedOptimizer(opt_mixin.WeightDecayMixin, optim.Optimizer):
    """Base class for orthogonalized optimizers.

    This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
    following papers:

    - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
      In International Conference on Artificial Intelligence and Statistics (2015a).
    - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V.
      *Stochastic Spectral Descent for Discrete Graphical Models.*
      In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).
    - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V.
      *Preconditioned spectral descent for deep learning.*
      In Neural Information Processing Systems (2015b).
    - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.*
      arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

    Note:
        OrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately.
        Subclass can override the orthogonalize function to support this, see example below.

    .. code-block:: python
       :caption: Split QKV example

       class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
           def __init__(..., split_qkv_shapes):
               super().__init__(...)
               self.qkv_split_shapes = split_qkv_shapes

           def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
               # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the
               # scaled_orthogonalize_fn.
               if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
                   qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
                   qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
                   grad = torch.cat(qkv_orthogonalized)
               else:
                   grad = self.scaled_orthogonalize_fn(grad)
               return grad

    Args:
        {_args_doc}
        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. When omitted, a no-op
            (identity) is used and a warning is logged.
        **kwargs: Arguments passed through to the base optimizer.

    Note:
        Keyword arguments passed through are not checked here. Optimizer inherited from this class should check them.
    """

    # NOTE: the class docstring above is used as a ``str.format`` template (see the module-level
    # statement after this class), so any literal curly brace added to it must be doubled.

    def __init__(
        self,
        params: ParamsT,
        lr: float,
        momentum_beta: float,
        weight_decay: float,
        *,
        use_nesterov: bool,
        weight_decay_method: opt_mixin.WeightDecayT,
        fp32_matmul_prec: FP32MatmulPrecT,
        scaled_orthogonalize_fn: Callable | None = None,
        **kwargs: Any,
    ):
        if scaled_orthogonalize_fn is None:
            logging.warning("scaled_orthogonalize_fn not provided. Using noop")
            scaled_orthogonalize_fn = torch.nn.Identity()

        # Optimizer-wide settings live on the instance; they are not per-group defaults.
        self.fp32_matmul_prec = fp32_matmul_prec
        self.use_nesterov = use_nesterov
        self.weight_decay_method = weight_decay_method

        # Per-group defaults. Extra keyword arguments become group entries as well and are later
        # forwarded to ``orthogonalize`` by ``step``.
        default_args_dict = dict(
            lr=lr,
            momentum_beta=momentum_beta,
            weight_decay=weight_decay,
            **kwargs,
        )

        super().__init__(params, default_args_dict)
        self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        For every parameter that has a gradient: weight decay is applied, the momentum buffer is updated
        as an EMA of the gradient, the (optionally Nesterov) momentum is orthogonalized and the result is
        subtracted from the parameter scaled by the group's learning rate.

        Args:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure was given.
        """
        loss: float | None = None
        if closure is not None:
            # step() itself runs under no_grad, but the closure usually calls backward();
            # autograd must be re-enabled for it, as the built-in torch optimizers do.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # All group entries except the parameter list are forwarded to orthogonalize().
            # They do not depend on the parameter, so build the dict once per group.
            group_kwargs = {k: v for k, v in group.items() if k != "params"}

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                state = self.state[p]

                # initialize momentum buffer
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)

                # Subsequent updates to exp_avg are all in-place, so it is not assigned back to state.
                exp_avg = state["momentum_buffer"]

                # Helper is not defined in this class; presumably provided by WeightDecayMixin.
                self._apply_weight_decay_inplace(
                    p,
                    grad,
                    group["lr"],
                    group["weight_decay"],
                )

                # update momentum buffer with EMA of gradient
                exp_avg.lerp_(grad, 1 - group["momentum_beta"])

                # include nesterov momentum
                if self.use_nesterov:
                    grad = grad.lerp(exp_avg, group["momentum_beta"])
                else:
                    grad = exp_avg

                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    orth_grad = self.orthogonalize(p, grad, **group_kwargs)

                # perform weight update
                # scale is applied to have update RMS == 1
                p.add_(orth_grad, alpha=-group["lr"])

        return loss

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Orthogonalize the momentum.

        The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can
        override this function to implement different orthogonalization logic as well as split fused parameters.
        For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if
        the parameter is a fused parameter and should be split for preconditioning.

        Note:
            N-D parameters can be supported by overriding this function. For example, convolution weight can be
            supported by reshaping to [output_channels, input_channels * kernel_height * kernel_width], i.e. treating
            convolution as matrix multiplication with im2col.

        Args:
            p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of
                information is only available in the param tensor, attributes for example. It is not used in
                this default orthogonalize function.
            grad: The momentum tensor.
            **kwargs: Keyword arguments of the param_group that p belongs to.

        Returns:
            The orthogonalized gradient tensor.

        Raises:
            ValueError: If ``grad`` is not a 2D tensor.
        """
        if grad.ndim != 2:
            raise ValueError("Only 2D parameters are supported.")
        grad = self.scaled_orthogonalize_fn(grad)
        return grad
# Splice the shared argument documentation into the class docstring's ``{_args_doc}`` placeholder.
# Docstrings are stripped (``__doc__`` is None) when Python runs with ``-OO``; skip the formatting in
# that case instead of raising AttributeError at import time.
if OrthogonalizedOptimizer.__doc__ is not None:
    OrthogonalizedOptimizer.__doc__ = OrthogonalizedOptimizer.__doc__.format(_args_doc=_args_doc)