From 28459460f5cc953ec1cc59b9a94620237359eea2 Mon Sep 17 00:00:00 2001
From: yurekami
Date: Mon, 29 Dec 2025 07:16:19 +0900
Subject: [PATCH] Fix Muon optimizer checkpoint resume with bf16 mode
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When resuming training from a checkpoint with bf16 enabled, the Muon
optimizer's momentum_buffer was loaded as fp32 (from the checkpoint)
while gradients were bf16, causing a dtype mismatch error in the
lerp_() operation.

This fix adds a load_state_dict override to all Muon optimizer classes
(Muon, SingleDeviceMuon, MuonWithAuxAdam, SingleDeviceMuonWithAuxAdam)
that casts the momentum_buffer (and exp_avg/exp_avg_sq for hybrid
classes) to match the parameter dtype after loading the checkpoint.

Fixes: https://github.com/deepspeedai/DeepSpeed/issues/7746

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
Signed-off-by: yurekami
---
 deepspeed/runtime/zero/muon/original_muon.py | 61 +++++++++++++++++
 tests/unit/ops/muon/test_muon.py             | 72 ++++++++++++++++++++
 2 files changed, 133 insertions(+)

diff --git a/deepspeed/runtime/zero/muon/original_muon.py b/deepspeed/runtime/zero/muon/original_muon.py
index f4dc7a0909bb..6a0858f90401 100644
--- a/deepspeed/runtime/zero/muon/original_muon.py
+++ b/deepspeed/runtime/zero/muon/original_muon.py
@@ -101,6 +101,24 @@ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
         params = sorted(params, key=lambda x: x.size(), reverse=True)
         super().__init__(params, defaults)
 
+    def load_state_dict(self, state_dict):
+        """Load optimizer state dict and cast momentum_buffer to match parameter dtype.
+
+        When resuming from a checkpoint with bf16 enabled, momentum_buffer may be saved as fp32
+        while parameters are bf16. This override ensures momentum_buffer dtype matches the
+        parameter dtype to prevent dtype mismatch errors in muon_update.
+
+        See: https://github.com/deepspeedai/DeepSpeed/issues/7746
+        """
+        super().load_state_dict(state_dict)
+        # Cast momentum_buffer to match parameter dtype after loading
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p in self.state and "momentum_buffer" in self.state[p]:
+                    buf = self.state[p]["momentum_buffer"]
+                    if buf.dtype != p.dtype:
+                        self.state[p]["momentum_buffer"] = buf.to(dtype=p.dtype)
+
     @torch.no_grad()
     def step(self, closure=None):
 
@@ -140,6 +158,19 @@ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
         super().__init__(params, defaults)
 
+    def load_state_dict(self, state_dict):
+        """Load optimizer state dict and cast momentum_buffer to match parameter dtype.
+
+        See: https://github.com/deepspeedai/DeepSpeed/issues/7746
+        """
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p in self.state and "momentum_buffer" in self.state[p]:
+                    buf = self.state[p]["momentum_buffer"]
+                    if buf.dtype != p.dtype:
+                        self.state[p]["momentum_buffer"] = buf.to(dtype=p.dtype)
+
     @torch.no_grad()
     def step(self, closure=None):
 
@@ -218,6 +249,21 @@ def __init__(self, param_groups):
                 assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
         super().__init__(param_groups, dict())
 
+    def load_state_dict(self, state_dict):
+        """Load optimizer state dict and cast buffers to match parameter dtype.
+
+        Handles both Muon buffers (momentum_buffer) and Adam buffers (exp_avg, exp_avg_sq).
+        See: https://github.com/deepspeedai/DeepSpeed/issues/7746
+        """
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p in self.state:
+                    state = self.state[p]
+                    for key in ["momentum_buffer", "exp_avg", "exp_avg_sq"]:
+                        if key in state and state[key].dtype != p.dtype:
+                            state[key] = state[key].to(dtype=p.dtype)
+
     @torch.no_grad()
     def step(self, closure=None):
 
@@ -287,6 +333,21 @@ def __init__(self, param_groups):
                 assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
         super().__init__(param_groups, dict())
 
+    def load_state_dict(self, state_dict):
+        """Load optimizer state dict and cast buffers to match parameter dtype.
+
+        Handles both Muon buffers (momentum_buffer) and Adam buffers (exp_avg, exp_avg_sq).
+        See: https://github.com/deepspeedai/DeepSpeed/issues/7746
+        """
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p in self.state:
+                    state = self.state[p]
+                    for key in ["momentum_buffer", "exp_avg", "exp_avg_sq"]:
+                        if key in state and state[key].dtype != p.dtype:
+                            state[key] = state[key].to(dtype=p.dtype)
+
     @torch.no_grad()
     def step(self, closure=None):
 
diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py
index f12cbb358a82..a9c74a15d7e1 100644
--- a/tests/unit/ops/muon/test_muon.py
+++ b/tests/unit/ops/muon/test_muon.py
@@ -72,3 +72,75 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
         after_training = [p.clone().cpu() for p in model.parameters()]
         for initial, final in zip(initial_params, after_training):
             assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"
+
+
+# Test configurations for bf16 checkpoint resume
+# Tests fix for https://github.com/deepspeedai/DeepSpeed/issues/7746
+bf16_checkpoint_configs = []
+for zero_stage in [1, 2]:
+    bf16_checkpoint_configs.append([zero_stage])
+
+
+@pytest.mark.parametrize('zero_stage', [1, 2])
+class TestMuonBF16CheckpointResume(DistributedTest):
+    """Test that Muon optimizer can resume training from checkpoint with bf16 enabled.
+
+    This tests the fix for issue #7746 where momentum_buffer dtype mismatch
+    caused crashes when resuming from checkpoint.
+ """ + + def test(self, zero_stage, tmpdir): + if torch.bfloat16 not in get_accelerator().supported_dtypes(): + pytest.skip("bf16 not supported on this accelerator") + + hidden_dim = 64 + nlayers = 3 + batch_size = 4 + + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": {"lr": 0.02} + }, + "bf16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + }, + "zero_allow_untested_optimizer": True, + } + + # Create model and train for a few steps to populate optimizer state + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + # Train for a few steps to create momentum_buffer state + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.bfloat16) + y = torch.randint(0, hidden_dim, (batch_size,), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() + + # Save checkpoint + ckpt_dir = str(tmpdir) + engine.save_checkpoint(ckpt_dir) + + # Load checkpoint + engine.load_checkpoint(ckpt_dir) + + # Resume training - this would fail before the fix due to dtype mismatch + # in momentum.lerp_(grad, 1 - beta) where momentum is fp32 and grad is bf16 + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.bfloat16) + y = torch.randint(0, hidden_dim, (batch_size,), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() # This should not raise dtype mismatch error