     weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
         See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
-    split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False.
-    is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
-    qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers
-        representing the sizes of Q, K, V components along the first dimension.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """

@@ -48,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer):
     """Base class for orthogonalized optimizers.

     This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
-    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers:
+    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
+    following papers:

     - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
       In International Conference on Artificial Intelligence and Statistics (2015a).
@@ -62,15 +59,33 @@ class OrthogonalizedOptimizer(optim.Optimizer):
       arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

     Note:
-        Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide
-        a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the
-        leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across
-        the network must have the same size.
+        The OrthogonalizedOptimizer base class does not directly support orthogonalizing fused parameters
+        separately. Subclasses can override the orthogonalize function to support this; see the example below.
+
+        .. code-block:: python
+            :caption: Split QKV example
+
+            class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
+                def __init__(..., split_qkv_shapes):
+                    super().__init__(...)
+                    self.qkv_split_shapes = split_qkv_shapes
+
+                def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+
+                    # An alternative is to pass "is_qkv" to scaled_orthogonalize_fn and split inside
+                    # scaled_orthogonalize_fn itself.
+                    if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
+                        qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
+                        qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
+                        grad = torch.cat(qkv_orthogonalized)
+                    else:
+                        grad = self.scaled_orthogonalize_fn(grad)
+
+                    return grad
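+
+        The scaled_orthogonalize_fn passed to the optimizer is expected to both orthogonalize and scale a
+        single update. The sketch below is illustrative only: the function name is made up, the update is
+        assumed to be a 2D matrix, and the scale shown is one possible choice rather than a prescribed rule.
+
+        .. code-block:: python
+            :caption: Example scaled_orthogonalize_fn (sketch)
+
+            def my_scaled_orthogonalize(grad: torch.Tensor) -> torch.Tensor:
+                # Orthogonalize via SVD: replace the update with its nearest semi-orthogonal matrix.
+                u, _, vh = torch.linalg.svd(grad, full_matrices=False)
+                ortho = u @ vh
+                # One possible scale: make the RMS of the orthogonalized update approximately 1.
+                return ortho * max(grad.size(0), grad.size(1)) ** 0.5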

     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
         **kwargs: Arguments passed through to the base optimizer.

     Note:
@@ -85,40 +100,13 @@ def __init__(
         use_nesterov: bool,
         weight_decay: float,
         use_decoupled_weight_decay: bool,
-        split_qkv: bool,
-        is_qkv_fn: Callable[[torch.Tensor], bool] | None,
-        qkv_split_shapes: tuple[int, int, int] | None,
         fp32_matmul_prec: str,
-        orthogonalize_fn: Callable | None = None,
-        scale_factor_fn: Callable | None = None,
+        scaled_orthogonalize_fn: Callable | None = None,
         **kwargs: Any,
     ):
-        if orthogonalize_fn is None:
-            logging.warning("orthogonalize_fn not provided. Using noop")
-            orthogonalize_fn = torch.nn.Identity()
-
-        if scale_factor_fn is None:
-            logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.")
-
-            def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return 1.0
-
-            scale_factor_fn = return_one
-
-        if split_qkv:
-            assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True"
-            assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True"
-            if len(qkv_split_shapes) != 3:
-                raise ValueError(
-                    f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements"
-                )
-            if not all(isinstance(s, int) for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}")
-            if any(s <= 0 for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}")
-        self.split_qkv = split_qkv
-        self.is_qkv_fn = is_qkv_fn
-        self.qkv_split_shapes = qkv_split_shapes
+        if scaled_orthogonalize_fn is None:
+            logging.warning("scaled_orthogonalize_fn not provided. Using noop")
+            scaled_orthogonalize_fn = torch.nn.Identity()

         self.fp32_matmul_prec = fp32_matmul_prec
         default_args_dict = dict(
@@ -131,8 +119,7 @@ def return_one(*args, **kwargs): # type: ignore[no-untyped-def]
         )

         super().__init__(params, default_args_dict)
-        self.orthogonalize_fn = orthogonalize_fn
-        self.scale_factor_fn = scale_factor_fn
+        self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

     @torch.no_grad()  # type: ignore[misc]
     @override
@@ -182,36 +169,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
                 p.add_(grad, alpha=-group["lr"])

         return loss

-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.

+        The default implementation calls scaled_orthogonalize_fn on the gradient. Subclasses can override
+        this method to implement different orthogonalization logic, including splitting fused parameters.
+        For example, a scaled_orthogonalize_fn can read attributes from p or from kwargs to determine whether
+        the parameter is fused and should be split for preconditioning; see the example below.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is passed in addition to the momentum because some information, such
+                as attributes, is only available on the parameter tensor. It is not used by this default
+                implementation.
             grad: The momentum tensor.
+            **kwargs: Keyword arguments of the param_group that p belongs to.

         Returns:
             The orthogonalized gradient tensor.
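+
+        Example:
+            A subclass override can also read per-group flags from kwargs, since step() forwards every
+            param_group entry except "params". This is only a sketch; "skip_orthogonalize" is a hypothetical
+            group key, not an existing option:
+
+            .. code-block:: python
+
+                def orthogonalize(self, p, grad, **kwargs):
+                    # Leave the momentum untouched for groups that opt out via a custom flag.
+                    if kwargs.get("skip_orthogonalize", False):
+                        return grad
+                    return self.scaled_orthogonalize_fn(grad)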
         """
-        if self.split_qkv and self.is_qkv_fn(p):  # type: ignore[misc]
-            logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1)
-            # split grouped attention parameters (e.g., QKV, GQA, etc.)
-            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
-            # Apply Newton-Schulz to each component
-            qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads]
-            qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads]
-            # Apply individual scales to each component and concatenate
-            grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)])
-        else:
-            grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1))
+        grad = self.scaled_orthogonalize_fn(grad)
         return grad

