     weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
         See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
-    split_fused: Whether to split fused parameters (QKV, GQA, etc.) for preconditioning, default to be False.
-    is_fused_fn: Function to check if a parameter is fused parameters (QKV, GQA, etc.).
-        If multiple types of parameters are fused, the function should return True for all of which needs to be
-        split for preconditioning.
-    split_fn: Function to split the fused parameters (QKV, GQA, etc.) into a list of parameters.
-        It should support all the types of parameters that is_fused_fn returns True for.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """
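For context on the entries deleted above, here is a minimal sketch of what an `is_fused_fn` / `split_fn` pair could have looked like. The `is_qkv` attribute and the three-way chunking are illustrative assumptions, not this repository's API:

```python
import torch


def is_fused_fn(p: torch.Tensor) -> bool:
    # Treat a parameter as fused if it carries a (hypothetical) tag marking it
    # as a combined QKV projection; real code might inspect names or shapes instead.
    return getattr(p, "is_qkv", False)


def split_fn(p: torch.Tensor) -> list[torch.Tensor]:
    # Split a fused QKV weight into three equal row blocks so each projection
    # can be preconditioned separately; a GQA layout would need unequal splits.
    return list(p.chunk(3, dim=0))
```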
@@ -70,8 +64,7 @@ class OrthogonalizedOptimizer(optim.Optimizer):

     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
         **kwargs: Arguments passed through to the base optimizer.

     Note:
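To make the new `scaled_orthogonalize_fn` argument concrete, here is a sketch of one possible callable for 2-D weight matrices. The SVD-based orthogonalization, the RMS-1 scaling, and the call signature are assumptions; the repository may use a cheaper scheme such as a Newton-Schulz iteration.

```python
import torch


def scaled_orthogonalize_fn(grad: torch.Tensor, **kwargs) -> torch.Tensor:
    # Assumes a 2-D update: replace it with the nearest semi-orthogonal matrix
    # (all singular values set to 1) via an SVD.
    u, _, vh = torch.linalg.svd(grad.float(), full_matrices=False)
    ortho = u @ vh
    # Scale so the entries of the update have RMS == 1, matching the comment in step().
    scale = max(grad.shape) ** 0.5
    return (ortho * scale).to(grad.dtype)
```

An optimizer would then be built with something like `OrthogonalizedOptimizer(model.parameters(), scaled_orthogonalize_fn=scaled_orthogonalize_fn, lr=...)`, assuming the constructor accepts it as a keyword argument as the docstring above suggests.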
@@ -155,21 +148,29 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
                 p.add_(grad, alpha=-group["lr"])

         return loss

-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.

+        The default orthogonalize function calls scaled_orthogonalize_fn with the gradient. Subclasses can
+        override this function to implement different orthogonalization logic as well as to split fused
+        parameters. For example, a scaled_orthogonalize_fn can read attributes from p or entries of kwargs to
+        determine whether the parameter is fused and should be split for preconditioning.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is necessary to pass the param tensor in addition to the momentum
+                because a lot of information, attributes for example, is only available on the param tensor,
+                although p is not used in this default orthogonalize function.
             grad: The momentum tensor.
+            **kwargs: Keyword arguments of the param_group that p belongs to.

         Returns:
             The orthogonalized gradient tensor.
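Since step() now forwards every non-"params" entry of the param group into orthogonalize(), a subclass can use per-group settings to decide when to split fused parameters. A minimal sketch follows; the "is_fused" and "num_fused" group keys are assumptions a user would define themselves, and the import path of OrthogonalizedOptimizer is not shown in this diff:

```python
import torch

# OrthogonalizedOptimizer is the class modified in this diff; import it from its module.


class SplitFusedOptimizer(OrthogonalizedOptimizer):
    """Illustrative subclass that splits fused (e.g. QKV) updates before orthogonalizing."""

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs) -> torch.Tensor:
        # Non-fused groups fall back to the default behaviour.
        if not kwargs.get("is_fused", False):
            return super().orthogonalize(p, grad, **kwargs)
        base = super().orthogonalize
        # Orthogonalize each projection block on its own, then restore the fused layout.
        chunks = grad.chunk(kwargs.get("num_fused", 3), dim=0)
        return torch.cat([base(p, c, **kwargs) for c in chunks], dim=0)
```

A user would opt in per group, e.g. `SplitFusedOptimizer([{"params": qkv_weights, "is_fused": True, "num_fused": 3}, {"params": other_weights}], ...)`, since PyTorch keeps extra keys on each param_group and the new step() passes them through.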