61 changes: 61 additions & 0 deletions deepspeed/runtime/zero/muon/original_muon.py
@@ -101,6 +101,24 @@ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
params = sorted(params, key=lambda x: x.size(), reverse=True)
super().__init__(params, defaults)

def load_state_dict(self, state_dict):
"""Load optimizer state dict and cast momentum_buffer to match parameter dtype.

When resuming from a checkpoint with bf16 enabled, momentum_buffer may be saved as fp32
while parameters are bf16. This override ensures momentum_buffer dtype matches the
parameter dtype to prevent dtype mismatch errors in muon_update.

See: https://github.com/deepspeedai/DeepSpeed/issues/7746
"""
super().load_state_dict(state_dict)
# Cast momentum_buffer to match parameter dtype after loading
for group in self.param_groups:
for p in group["params"]:
if p in self.state and "momentum_buffer" in self.state[p]:
buf = self.state[p]["momentum_buffer"]
if buf.dtype != p.dtype:
self.state[p]["momentum_buffer"] = buf.to(dtype=p.dtype)

@torch.no_grad()
def step(self, closure=None):

@@ -140,6 +158,19 @@ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)

def load_state_dict(self, state_dict):
"""Load optimizer state dict and cast momentum_buffer to match parameter dtype.

See: https://github.com/deepspeedai/DeepSpeed/issues/7746
"""
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
if p in self.state and "momentum_buffer" in self.state[p]:
buf = self.state[p]["momentum_buffer"]
if buf.dtype != p.dtype:
self.state[p]["momentum_buffer"] = buf.to(dtype=p.dtype)

@torch.no_grad()
def step(self, closure=None):

@@ -218,6 +249,21 @@ def __init__(self, param_groups):
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
super().__init__(param_groups, dict())

def load_state_dict(self, state_dict):
"""Load optimizer state dict and cast buffers to match parameter dtype.

Handles both Muon buffers (momentum_buffer) and Adam buffers (exp_avg, exp_avg_sq).
See: https://github.com/deepspeedai/DeepSpeed/issues/7746
"""
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
if p in self.state:
state = self.state[p]
for key in ["momentum_buffer", "exp_avg", "exp_avg_sq"]:
if key in state and state[key].dtype != p.dtype:
state[key] = state[key].to(dtype=p.dtype)

@torch.no_grad()
def step(self, closure=None):

@@ -287,6 +333,21 @@ def __init__(self, param_groups):
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
super().__init__(param_groups, dict())

def load_state_dict(self, state_dict):
"""Load optimizer state dict and cast buffers to match parameter dtype.

Handles both Muon buffers (momentum_buffer) and Adam buffers (exp_avg, exp_avg_sq).
See: https://github.com/deepspeedai/DeepSpeed/issues/7746
"""
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
if p in self.state:
state = self.state[p]
for key in ["momentum_buffer", "exp_avg", "exp_avg_sq"]:
if key in state and state[key].dtype != p.dtype:
state[key] = state[key].to(dtype=p.dtype)

@torch.no_grad()
def step(self, closure=None):

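For reference, all four load_state_dict overrides above apply the same pattern: let the base class restore the saved state, then walk the param groups and cast any tracked state buffers back to the owning parameter's dtype. A minimal standalone sketch of that pattern follows (not part of the diff; the helper name cast_state_to_param_dtype is hypothetical):

import torch

def cast_state_to_param_dtype(optimizer, keys=("momentum_buffer", "exp_avg", "exp_avg_sq")):
    # After Optimizer.load_state_dict restores checkpointed tensors (often saved as fp32),
    # realign each tracked state buffer with its parameter's dtype (e.g. bf16).
    for group in optimizer.param_groups:
        for p in group["params"]:
            state = optimizer.state.get(p, {})
            for key in keys:
                buf = state.get(key)
                if torch.is_tensor(buf) and buf.dtype != p.dtype:
                    state[key] = buf.to(dtype=p.dtype)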
72 changes: 72 additions & 0 deletions tests/unit/ops/muon/test_muon.py
@@ -72,3 +72,75 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
after_training = [p.clone().cpu() for p in model.parameters()]
for initial, final in zip(initial_params, after_training):
assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"


# bf16 checkpoint-resume tests
# Cover the fix for https://github.com/deepspeedai/DeepSpeed/issues/7746


@pytest.mark.parametrize('zero_stage', [1, 2])
class TestMuonBF16CheckpointResume(DistributedTest):
"""Test that Muon optimizer can resume training from checkpoint with bf16 enabled.

This tests the fix for issue #7746 where momentum_buffer dtype mismatch
caused crashes when resuming from checkpoint.
"""

def test(self, zero_stage, tmpdir):
if torch.bfloat16 not in get_accelerator().supported_dtypes():
pytest.skip("bf16 not supported on this accelerator")

hidden_dim = 64
nlayers = 3
batch_size = 4

config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": "muon",
"params": {"lr": 0.02}
},
"bf16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage,
},
"zero_allow_untested_optimizer": True,
}

# Create model and train for a few steps to populate optimizer state
model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
engine, optimizer, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)

# Train for a few steps to create momentum_buffer state
for _ in range(3):
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.bfloat16)
y = torch.randint(0, hidden_dim, (batch_size,), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step()

# Save checkpoint
ckpt_dir = str(tmpdir)
engine.save_checkpoint(ckpt_dir)

# Load checkpoint
engine.load_checkpoint(ckpt_dir)

# Resume training - this would fail before the fix due to dtype mismatch
# in momentum.lerp_(grad, 1 - beta) where momentum is fp32 and grad is bf16
for _ in range(3):
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.bfloat16)
y = torch.randint(0, hidden_dim, (batch_size,), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step() # This should not raise dtype mismatch error
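The update that the resumed step exercises can also be sketched with plain tensors (illustrative shapes and values only): per the comment above, the call that failed before the fix was momentum.lerp_(grad, 1 - beta) with an fp32 momentum buffer and a bf16 gradient; after the load_state_dict cast the two dtypes agree and the in-place lerp proceeds.

import torch

beta = 0.95
momentum = torch.zeros(8, 8, dtype=torch.float32)  # momentum_buffer as restored from an fp32 checkpoint
grad = torch.randn(8, 8, dtype=torch.bfloat16)     # gradient produced by the resumed bf16 step
momentum = momentum.to(dtype=grad.dtype)           # what the load_state_dict override now guarantees
momentum.lerp_(grad, 1 - beta)                     # mirrors the update inside muon_update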