|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | from torch import nn |
| 6 | +from torch.distributed import all_gather, get_rank, get_world_size |
6 | 7 | from torch.optim import Optimizer |
7 | 8 |
|
8 | 9 | from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError |
@@ -216,6 +217,209 @@ def step(self, closure: CLOSURE = None) -> LOSS: |
216 | 217 | return loss |
217 | 218 |
|
218 | 219 |
|
| 220 | +class DistributedMuon(BaseOptimizer): # pragma: no cover |
| 221 | + r"""Distributed Momentum Orthogonalized by Newton-Schulz. |
| 222 | +
|
| 223 | + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which |
| 224 | + each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each |
| 225 | + update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. |
| 226 | +
|
| 227 | + Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and |
| 228 | + scalar or vector parameters should be optimized using AdamW. |
| 229 | +
|
| 230 | + Some warnings: |
| 231 | + - We believe this optimizer is unlikely to work well for training with small batch size. |
| 232 | + - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this. |
| 233 | +
|
| 234 | + Example: |
| 235 | + ------- |
| 236 | + from pytorch_optimizer import DistributedMuon |
| 237 | +
|
| 238 | + hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2] |
| 239 | + hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2] |
| 240 | + non_hidden_params = [*model.head.parameters(), *model.embed.parameters()] |
| 241 | +
|
| 242 | + param_groups = [ |
| 243 | + dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True), |
| 244 | + dict( |
| 245 | + params=hidden_gains_biases + non_hidden_params, |
| 246 | + lr=3e-4, |
| 247 | + betas=(0.9, 0.95), |
| 248 | + weight_decay=0.01, |
| 249 | + use_muon=False, |
| 250 | + ), |
| 251 | + ] |
| 252 | +
|
| 253 | + optimizer = DistributedMuon(param_groups) |
| 254 | +
|
| 255 | + :param params: PARAMETERS. the parameters to be optimized by Muon. |
| 256 | + :param lr: float. learning rate. |
| 257 | + :param momentum: float. the momentum used by the internal SGD. |
| 258 | + :param weight_decay: float. weight decay (L2 penalty). |
| 259 | + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. |
| 260 | + :param nesterov: bool. whether to use Nesterov momentum. |
| 261 | + :param ns_steps: int. the number of Newton-Schulz iterations to run. (5 is probably always enough) |
| 262 | + :param use_adjusted_lr: bool. whether to use the adjusted learning rate proposed in Moonlight. |
| 263 | + reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py |
| 264 | + :param adamw_lr: float. the learning rate for the internal AdamW. |
| 265 | + :param adamw_betas: BETAS. the betas for the internal AdamW. |
| 266 | + :param adamw_wd: float. the weight decay for the internal AdamW. |
| 267 | + :param adamw_eps: float. the epsilon for the internal AdamW. |
| 268 | + :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. |
| 269 | + """ |
| 270 | + |
| 271 | + def __init__( |
| 272 | + self, |
| 273 | + params: PARAMETERS, |
| 274 | + lr: float = 2e-2, |
| 275 | + momentum: float = 0.95, |
| 276 | + weight_decay: float = 0.0, |
| 277 | + weight_decouple: bool = True, |
| 278 | + nesterov: bool = True, |
| 279 | + ns_steps: int = 5, |
| 280 | + use_adjusted_lr: bool = False, |
| 281 | + adamw_lr: float = 3e-4, |
| 282 | + adamw_betas: BETAS = (0.9, 0.95), |
| 283 | + adamw_wd: float = 0.0, |
| 284 | + adamw_eps: float = 1e-10, |
| 285 | + maximize: bool = False, |
| 286 | + **kwargs, |
| 287 | + ): |
| 288 | + self.validate_learning_rate(lr) |
| 289 | + self.validate_learning_rate(adamw_lr) |
| 290 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 291 | + self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)') |
| 292 | + self.validate_positive(ns_steps, 'ns_steps') |
| 293 | + self.validate_betas(adamw_betas) |
| 294 | + self.validate_non_negative(adamw_wd, 'adamw_wd') |
| 295 | + self.validate_non_negative(adamw_eps, 'adamw_eps') |
| 296 | + |
| 297 | + self.maximize = maximize |
| 298 | + |
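| | + # NOTE: torch.distributed must already be initialized (e.g. via `init_process_group`), |
| | + # since `get_world_size` and `get_rank` query the default process group. |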
| 299 | + self.world_size: int = get_world_size() |
| 300 | + self.rank: int = get_rank() |
| 301 | + |
| 302 | + for group in params: |
| 303 | + if 'use_muon' not in group: |
| 304 | + raise ValueError('`use_muon` must be set.') |
| 305 | + |
| 306 | + if group['use_muon']: |
| 307 | + group['lr'] = group.get('lr', lr) |
| 308 | + group['momentum'] = group.get('momentum', momentum) |
| 309 | + group['nesterov'] = group.get('nesterov', nesterov) |
| 310 | + group['weight_decay'] = group.get('weight_decay', weight_decay) |
| 311 | + group['ns_steps'] = group.get('ns_steps', ns_steps) |
| 312 | + group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr) |
| 313 | + else: |
| 314 | + group['lr'] = group.get('lr', adamw_lr) |
| 315 | + group['betas'] = group.get('betas', adamw_betas) |
| 316 | + group['eps'] = group.get('eps', adamw_eps) |
| 317 | + group['weight_decay'] = group.get('weight_decay', adamw_wd) |
| 318 | + |
| 319 | + group['weight_decouple'] = group.get('weight_decouple', weight_decouple) |
| 320 | + |
| 321 | + super().__init__(params, kwargs) |
| 322 | + |
| 323 | + def __str__(self) -> str: |
| 324 | + return 'DistributedMuon' |
| 325 | + |
| 326 | + def init_group(self, group: GROUP, **kwargs) -> None: |
| 327 | + for p in group['params']: |
| 328 | + if p.grad is None: |
| 329 | + p.grad = torch.zeros_like(p) |
| 330 | + |
| 331 | + grad = p.grad |
| 332 | + if grad.is_sparse: |
| 333 | + raise NoSparseGradientError(str(self)) |
| 334 | + |
| 335 | + if torch.is_complex(p): |
| 336 | + raise NoComplexParameterError(str(self)) |
| 337 | + |
| 338 | + state = self.state[p] |
| 339 | + |
| 340 | + if len(state) == 0 and not group['use_muon']: |
| 341 | + state['exp_avg'] = torch.zeros_like(p) |
| 342 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 343 | + |
| 344 | + @torch.no_grad() |
| 345 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 346 | + loss: LOSS = None |
| 347 | + if closure is not None: |
| 348 | + with torch.enable_grad(): |
| 349 | + loss = closure() |
| 350 | + |
| 351 | + for group in self.param_groups: |
| 352 | + if 'step' not in group: |
| 353 | + self.init_group(group) |
| 354 | + group['step'] = 1 |
| 355 | + else: |
| 356 | + group['step'] += 1 |
| 357 | + |
| 358 | + if group['use_muon']: |
| 359 | + params = group['params'] |
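| | + # pad the parameter list to a multiple of `world_size` so every rank contributes |
| | + # a tensor to `all_gather`; the real parameters are sharded round-robin, with |
| | + # rank `r` updating `params[i + r]` in each block of `world_size` parameters. |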
| 360 | + padded_params = params + [torch.empty_like(params[-1])] * ( |
| 361 | + self.world_size - len(params) % self.world_size |
| 362 | + ) |
| 363 | + |
| 364 | + for i in range(len(params))[:: self.world_size]: |
| 365 | + if i + self.rank < len(params): |
| 366 | + p = params[i + self.rank] |
| 367 | + |
| 368 | + grad = p.grad |
| 369 | + |
| 370 | + self.maximize_gradient(grad, maximize=self.maximize) |
| 371 | + |
| 372 | + state = self.state[p] |
| 373 | + if len(state) == 0: |
| 374 | + state['momentum_buffer'] = torch.zeros_like(p) |
| 375 | + |
| 376 | + self.apply_weight_decay( |
| 377 | + p, |
| 378 | + grad=grad, |
| 379 | + lr=group['lr'], |
| 380 | + weight_decay=group['weight_decay'], |
| 381 | + weight_decouple=group['weight_decouple'], |
| 382 | + fixed_decay=False, |
| 383 | + ) |
| 384 | + |
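| | + # the momentum buffer is an EMA of the gradient; with Nesterov, the gradient is |
| | + # blended toward the buffer before orthogonalization |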
| 385 | + buf = state['momentum_buffer'] |
| 386 | + buf.lerp_(grad, weight=1.0 - group['momentum']) |
| 387 | + |
| 388 | + update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf |
| 389 | + if update.ndim > 2: |
| 390 | + update = update.view(len(update), -1) |
| 391 | + |
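| | + # orthogonalize the (flattened) 2D update: the Newton-Schulz iteration |
| | + # approximates the nearest semi-orthogonal matrix and runs stably in bfloat16 |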
| 392 | + update = zero_power_via_newton_schulz_5(update, num_steps=group['ns_steps']) |
| 393 | + |
| 394 | + if group.get('cautious'): |
| 395 | + self.apply_cautious(update, grad) |
| 396 | + |
| 397 | + lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr']) |
| 398 | + |
| 399 | + p.add_(update.reshape(p.shape), alpha=-lr) |
| 400 | + |
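| | + # synchronize: each rank broadcasts the parameter it just updated so that |
| | + # every replica finishes the step with identical weights |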
| 401 | + all_gather(padded_params[i:i + self.world_size], padded_params[i + self.rank]) # fmt: skip |
| 402 | + else: |
| 403 | + for p in group['params']: |
| 404 | + grad = p.grad |
| | + self.maximize_gradient(grad, maximize=self.maximize) |
| 405 | + |
| | + state = self.state[p] |
| 406 | + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 407 | + |
| 408 | + beta1, beta2 = group['betas'] |
| 409 | + |
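| | + # standard Adam bias corrections for the first and second moment estimates |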
| 410 | + bias_correction1: float = self.debias(beta1, group['step']) |
| 411 | + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) |
| 412 | + |
| 413 | + exp_avg.lerp_(grad, weight=1.0 - beta1) |
| 414 | + exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2) |
| 415 | + |
| 416 | + de_nom = exp_avg_sq.sqrt().add_(group['eps']).div_(bias_correction2_sq) |
| 417 | + |
| | + # decoupled (AdamW-style) weight decay for the non-Muon parameter groups |
| | + self.apply_weight_decay( |
| | + p, grad=grad, lr=group['lr'], weight_decay=group['weight_decay'], |
| | + weight_decouple=group['weight_decouple'], fixed_decay=False, |
| | + ) |
| 418 | + p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr']) |
| 419 | + |
| 420 | + return loss |
| 421 | + |
| 422 | + |
219 | 423 | class AdaMuon(BaseOptimizer): |
220 | 424 | r"""Adaptive Muon optimizer. |
221 | 425 |
|
|