284 changes: 284 additions & 0 deletions olmo/optim.py
@@ -19,6 +19,7 @@
"Optimizer",
"LionW",
"AdamW",
"MuonW",
"Scheduler",
"CosWithWarmup",
"LinearWithWarmup",
@@ -646,6 +647,289 @@ def get_post_step_metrics(
self._step_size_maxs = None
return metrics

# MuonAdamW optimizer - experimental implementation.
# Can be instantiated directly without config changes, e.g.:
#   optimizer = MuonAdamW(model.parameters(), lr=0.02)
# Note that step() expects each param group to carry "param_names", so in practice
# param groups should be built with parameter names attached.


class MuonAdamW(Optimizer):
"""
Muon optimizer with AdamW fallback for non-matrix parameters.

    Muon (Momentum Orthogonalized by Newton-Schulz) is an optimizer that uses
orthogonalization of updates for matrix parameters to achieve better
conditioning. For non-matrix parameters and certain excluded layers
(embeddings, output heads), it falls back to AdamW.

Reference:
"Muon: Momentum Orthogonalized by Newton-schulz"
https://github.com/KellerJordan/Muon

Args:
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: Learning rate (default: 0.01)
        betas: (beta1, beta2); beta1 is the Muon momentum coefficient, and both are used as the running-average coefficients of the AdamW fallback (default: (0.95, 0.95))
weight_decay: Weight decay coefficient (default: 0.0)
ns_steps: Number of Newton-Schulz iterations (default: 5)
nesterov: Whether to use Nesterov momentum (default: True)
eps: Term added to denominator for AdamW (default: 1e-8)
record_update_metrics: Whether to record update metrics (default: False)
selective_updates: Whether to use selective weight updates (default: False)

Note:
        - Matrix parameters (2D+) use Muon unless their name contains 'embed' or 'head'
- Non-matrix parameters always use AdamW
- Weight decay is applied AdamW-style (decoupled)
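
    Example:
        A minimal sketch; the layout below is illustrative only, and step()
        expects each param group to carry "param_names":

            named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
            optimizer = MuonAdamW(
                [{"params": [p for _, p in named], "param_names": [n for n, _ in named]}],
                lr=0.02,
            )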
"""
def __init__(
self,
params,
lr=0.01,
betas=(0.95, 0.95), # Muon uses single momentum param
weight_decay=0.0,
ns_steps=5,
nesterov=True,
eps=1e-8, # For AdamW backup
record_update_metrics=False,
selective_updates=False,
device=None,
):
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
# User provided param groups
for param_group in params:
if 'use_muon' not in param_group:
param_group['use_muon'] = True
else:
# Convert single params list to a param group
params = [{'params': params, 'use_muon': True}]

defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
ns_steps=ns_steps,
nesterov=nesterov,
eps=eps,
use_muon=True, # Default to using Muon
)
super().__init__(
params,
defaults,
record_update_metrics=record_update_metrics,
selective_updates=selective_updates
)
self._device = device
self._update_norms = None
self._update_maxs = None
self._update_param_names = None

def zeropower_via_newtonschulz5(self, G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
"""
assert G.ndim >= 2
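        # Quintic-iteration coefficients taken from the reference Muon implementation;
        # they favor fast approximate orthogonalization over exact convergence to U @ V^T.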
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X

def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
"""Return optimizer state for a parameter."""
state = self.state[param]
if not state:
return {}

result = {}
if 'momentum_buffer' in state:
result['momentum_buffer'] = state['momentum_buffer']
if 'exp_avg' in state:
result['exp_avg'] = state['exp_avg']
if 'exp_avg_sq' in state:
result['exp_avg_sq'] = state['exp_avg_sq']

return result

@torch.no_grad()
def step(self, closure=None):
"""Perform a single optimization step."""
if closure is not None:
with torch.enable_grad():
closure()

device = get_default_device() if self._device is None else self._device
update_norms = []
update_maxs = []
update_param_names = []

collecting_metrics = self._collecting_metrics and self._record_update_metrics

for group in self.param_groups:
lr = group['lr']
weight_decay = group['weight_decay']
beta1, beta2 = group['betas']
ns_steps = group['ns_steps']
nesterov = group['nesterov']
eps = group['eps']
use_muon = group['use_muon']

for name, p in zip(group["param_names"], group["params"]):
name = self._clean_param_name(name)

if p.grad is None:
if collecting_metrics:
update_param_names.append(name)
update_norms.append(torch.tensor([0.0], device=device))
update_maxs.append(torch.tensor([0.0], device=device))
continue

                # Decoupled (AdamW-style) weight decay; with selective updates,
                # only parameters with a nonzero gradient are decayed.
                mask = (p.grad != 0) if self._selective_updates else torch.ones_like(p, dtype=torch.bool)
                p.mul_(1 - mask * (lr * weight_decay))

grad = p.grad
state = self.state[p]

# Determine whether to use Muon or AdamW for this parameter
# We use Muon for matrix parameters unless explicitly disabled
should_use_muon = use_muon and p.ndim >= 2 and not ('embed' in name.lower() or 'head' in name.lower())
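                # Per the Muon authors' guidance, embeddings, output heads, and
                # non-matrix parameters are better served by AdamW, so they take
                # the AdamW branch below.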

if should_use_muon:
# --- Muon Update Logic ---

# Initialize momentum buffer if needed
if 'momentum_buffer' not in state:
state['momentum_buffer'] = torch.zeros_like(grad)
momentum_buffer = state['momentum_buffer']

# Update momentum
momentum_buffer.lerp_(grad, mask * (1 - beta1))

# Compute update
if nesterov:
update = momentum_buffer * beta1 + grad * (1 - beta1)
else:
update = momentum_buffer.clone()

if isinstance(mask, torch.Tensor):
update.mul_(mask)

# Handle conv filters
orig_shape = update.shape
if update.ndim == 4:
update = update.view(update.shape[0], -1)

# Apply Newton-Schulz
update = self.zeropower_via_newtonschulz5(update, steps=ns_steps)

                    # Scale by sqrt(max(1, rows/cols)) of the (possibly flattened)
                    # 2D update, matching the reference Muon scaling.
                    update *= max(1, update.size(-2) / update.size(-1)) ** 0.5

# Reshape if needed
if len(orig_shape) == 4:
update = update.view(orig_shape)

else:
# --- AdamW Update Logic ---

# Initialize momentum buffers if needed
if 'exp_avg' not in state:
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)
state['step'] = 0

# Update step count
state['step'] += 1
step = state['step']

# Update momentum buffers
state['exp_avg'].lerp_(grad, mask * (1 - beta1))
state['exp_avg_sq'].mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)

# Bias correction
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Compute AdamW update
denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
update = state['exp_avg'] / bias_correction1 / denom

if isinstance(mask, torch.Tensor):
update.mul_(mask)

# Apply update
p.add_(update, alpha=-lr)

# Collect metrics
if collecting_metrics:
update_param_names.append(name)
update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32).unsqueeze(0))
update_maxs.append(update.abs().max().unsqueeze(0))

# Store metrics
if collecting_metrics:
self._update_norms = update_norms
self._update_maxs = update_maxs
self._update_param_names = update_param_names

        return loss

def get_post_step_metrics(
self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
) -> Dict[str, torch.Tensor]:
"""Get metrics about the optimization step."""
if not (self._record_update_metrics and self._collecting_metrics):
return {}

device = get_default_device() if self._device is None else self._device
dst_rank = 0
if process_group is not None:
dst_rank = dist.get_global_rank(process_group, 0)

param_names = self._update_param_names
update_norms = self._update_norms
update_maxs = self._update_maxs

if param_names is None or update_norms is None or update_maxs is None:
return {}

# Reduce metrics if needed
if is_distributed() and isinstance(module, FullyShardedDataParallel):
# Reduce norms
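            # Each rank holds the L2 norm of its local shard, so we sum squared
            # norms across ranks and take the square root to recover global norms.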
all_norms = torch.cat(update_norms).to(device) ** 2.0
dist.reduce(all_norms, dst_rank, op=dist.ReduceOp.SUM, group=process_group)
update_norms = (all_norms ** (0.5)).squeeze(0).split(1)

# Reduce maxs
all_maxs = torch.cat(update_maxs).to(device)
dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
update_maxs = all_maxs.split(1)

# Collect metrics
metrics = {}
for param_name, update_norm, update_max in zip(param_names, update_norms, update_maxs):
metrics[f"update/{param_name}.norm"] = update_norm.squeeze(0)
metrics[f"update/{param_name}.max"] = update_max.squeeze(0)

# Reset stored metrics
self._update_norms = None
self._update_maxs = None
self._update_param_names = None

return metrics
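
# Example (sketch): explicitly routing embeddings / output heads to the AdamW
# fallback via per-group "use_muon" flags. The variable names below are
# illustrative; step() reads "param_names" from each group.
#
#   hidden = [(n, p) for n, p in model.named_parameters()
#             if p.ndim >= 2 and "embed" not in n and "head" not in n]
#   other = [(n, p) for n, p in model.named_parameters()
#            if not (p.ndim >= 2 and "embed" not in n and "head" not in n)]
#   optimizer = MuonAdamW(
#       [
#           {"params": [p for _, p in hidden], "param_names": [n for n, _ in hidden], "use_muon": True},
#           {"params": [p for _, p in other], "param_names": [n for n, _ in other], "use_muon": False},
#       ],
#       lr=0.02,
#       weight_decay=0.1,
#   )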


@dataclass
class Scheduler(metaclass=ABCMeta):