|
1 | 1 | import math |
2 | 2 | import os |
3 | | -from typing import List, Optional |
| 3 | +from typing import List, Optional, Tuple |
4 | 4 |
|
5 | 5 | import torch |
6 | 6 | from torch.distributed import ReduceOp, all_reduce |
@@ -131,9 +131,18 @@ def reset(self): |
131 | 131 | state['moment2'] = torch.zeros_like(p) |
132 | 132 |
|
133 | 133 | @staticmethod |
134 | | - def adjust_lr_for_muon(lr: float, param_shape) -> float: |
135 | | - adjusted_ratio: float = 0.2 * math.sqrt(max(param_shape[0], param_shape[1])) |
136 | | - return lr * adjusted_ratio |
| 134 | + def get_adjusted_lr(lr: float, param_shape: Tuple[int, ...], use_adjusted_lr: bool = False) -> float: |
| 135 | + r"""Get the adjusted learning rate.""" |
| 136 | + output_shape, *input_shape = param_shape |
| 137 | + input_shape = math.prod(input_shape) |
| 138 | + |
| 139 | + ratio: float = ( |
| 140 | + math.pow(max(1.0, output_shape / input_shape), 0.5) |
| 141 | + if use_adjusted_lr |
| 142 | + else 0.2 * math.sqrt(max(output_shape, input_shape)) |
| 143 | + ) |
| 144 | + |
| 145 | + return lr * ratio |
137 | 146 |
|
138 | 147 | @torch.no_grad() |
139 | 148 | def step(self, closure: CLOSURE = None) -> LOSS: |
@@ -202,9 +211,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: |
202 | 211 | fixed_decay=False, |
203 | 212 | ) |
204 | 213 |
|
205 | | - lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr'] |
| 214 | + lr: float = self.get_adjusted_lr(group['lr'], p.size(), group['use_adjusted_lr']) |
206 | 215 |
|
207 | | - p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)) |
| 216 | + p.add_(g, alpha=-lr) |
208 | 217 | curr_idx += p.numel() |
209 | 218 |
|
210 | 219 | params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']] |
|
0 commit comments