3 changes: 2 additions & 1 deletion xtuner/v1/config/__init__.py
@@ -1,6 +1,6 @@
from .fsdp import FSDPConfig
from .generate import GenerateConfig
from .optim import AdamWConfig, LRConfig, OptimConfig
from .optim import AdamWConfig, LRConfig, MuonConfig, OptimConfig


__all__ = [
@@ -9,4 +9,5 @@
"AdamWConfig",
"LRConfig",
"GenerateConfig",
"MuonConfig",
]
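With this re-export, MuonConfig can be constructed straight from the config package alongside AdamWConfig. A minimal sketch of the intended usage, assuming lr is inherited from OptimConfig (whose fields are collapsed in this diff) and using the defaults declared in optim.py below:

```python
# Minimal sketch: constructing the newly exported MuonConfig.
# Field defaults mirror those declared in xtuner/v1/config/optim.py below;
# `lr` is assumed to be a field inherited from OptimConfig.
from xtuner.v1.config import MuonConfig

muon_cfg = MuonConfig(
    lr=2e-4,
    weight_decay=0.1,
    momentum=0.95,
    betas=(0.9, 0.95),
    eps=1e-8,
)
```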
97 changes: 96 additions & 1 deletion xtuner/v1/config/optim.py
@@ -2,10 +2,17 @@
from typing import Literal, Optional, Tuple

import torch
import torch.distributed as dist
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated

from xtuner.v1.optim import Muon
from xtuner.v1.utils import get_logger


logger = get_logger()


class OptimConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig):
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None

def build(self, params):
def build(self, model):
params = [p for p in model.parameters() if p.requires_grad]

trainable_parameters_names = model.trainable_parameters()
trainable_names = [name for name, _ in trainable_parameters_names]
untrainable_names = []
num_total_requires_grad = 0
num_total = 0
for name, params_ in model.named_parameters():
num_total += params_.numel()
num_total_requires_grad += params_.numel() if name in trainable_names else 0
if name not in trainable_names:
untrainable_names.append(name)

if dist.get_rank() == 0:
logger.info(
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
)
logger.info(f"Untrainable parameters names: {untrainable_names}")
return torch.optim.AdamW(
params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
)


class MuonConfig(OptimConfig):
weight_decay: Annotated[float, Parameter(help="Weight decay coefficient for L2 regularization")] = 0.1
momentum: Annotated[float, Parameter(help="Momentum coefficient for the Muon optimizer")] = 0.95
betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for the AdamW parameter group")] = (0.9, 0.95)
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8

def build(self, model):
trainable_parameters_names = model.trainable_parameters()
trainable_names = {name for name, _ in trainable_parameters_names}

untrainable_names = []
num_total = 0
num_total_requires_grad = 0
num_muon = 0
num_adamw = 0

for name, p in model.named_parameters():
n = p.numel()
num_total += n
if name in trainable_names:
num_total_requires_grad += n
is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
if is_muon_tensor:
num_muon += n
else:
num_adamw += n
else:
untrainable_names.append(name)

muon_params = [
p
for name, p in model.named_parameters()
if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
]
adamw_params = [
p
for name, p in model.named_parameters()
if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
]
param_groups = [
dict(params=muon_params),
dict(params=adamw_params, algorithm="adamw"),
]

if dist.get_rank() == 0:
logger.info(
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
)
logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
logger.info(f"Untrainable parameters names: {untrainable_names}")
logger.info(
f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, "
f"distributed_mesh: {model.fsdp_mesh}"
)

optimizer = Muon(
param_groups,
distributed_mesh=model.fsdp_mesh,  # TODO: EP > 1 is not yet supported
lr=self.lr,
mu=self.momentum,
betas=self.betas,
weight_decay=self.weight_decay,
nesterov=True,
adjust_lr="rms_norm",
use_triton=False,
epsilon=self.eps,
)
return optimizer


class LRConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
lr_type: Annotated[Literal["cosine", "linear", "constant"], Parameter(help="Type of learning rate schedule")] = (
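The routing rule in MuonConfig.build sends every trainable tensor with two or more dimensions to Muon, while embedding weights (embed_tokens), the output head (lm_head), and all 1-D parameters such as biases and norm scales fall back to the AdamW group. A self-contained sketch of that predicate, with toy parameter names that are purely illustrative:

```python
import torch


def is_muon_tensor(name: str, p: torch.Tensor) -> bool:
    """Same predicate as in MuonConfig.build: matrices go to Muon;
    embeddings, lm_head, and 1-D tensors stay on AdamW."""
    return p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name


# Toy parameters (names are hypothetical, chosen only to exercise each branch):
toy_params = {
    "model.embed_tokens.weight": torch.empty(100, 64),              # -> AdamW (embedding)
    "model.layers.0.self_attn.q_proj.weight": torch.empty(64, 64),  # -> Muon  (2-D matrix)
    "model.layers.0.input_layernorm.weight": torch.empty(64),       # -> AdamW (1-D)
    "lm_head.weight": torch.empty(100, 64),                         # -> AdamW (output head)
}
for name, p in toy_params.items():
    print(f"{name}: {'Muon' if is_muon_tensor(name, p) else 'AdamW'}")
```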
20 changes: 1 addition & 19 deletions xtuner/v1/engine/train_engine.py
@@ -184,25 +184,7 @@ def build_model(self) -> BaseModel:
return model

def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
params = [p for p in self.model.parameters() if p.requires_grad]

trainable_parameters_names = self.model.trainable_parameters()
trainable_names = [name for name, _ in trainable_parameters_names]
untrainable_names = []
num_total_requires_grad = 0
num_total = 0
for name, params_ in self.model.named_parameters():
num_total += params_.numel()
num_total_requires_grad += params_.numel() if name in trainable_names else 0
if name not in trainable_names:
untrainable_names.append(name)

if dist.get_rank() == 0:
logger.info(
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
)
logger.info(f"Untrainable parameters names: {untrainable_names}")
return optim_cfg.build(params)
return optim_cfg.build(self.model)

@property
def data_replicate_size(self) -> int:
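With the grouping and logging moved into the config classes, build_optimizer now just forwards the whole model. A hedged sketch of the resulting call path, assuming engine is an already-constructed TrainEngine:

```python
# Hypothetical usage sketch; `engine` is assumed to be a constructed TrainEngine
# whose model exposes trainable_parameters(), named_parameters(), and fsdp_mesh.
from xtuner.v1.config import AdamWConfig, MuonConfig

optim_cfg = MuonConfig(lr=2e-4)                # or AdamWConfig(lr=2e-4)
optimizer = engine.build_optimizer(optim_cfg)  # internally calls optim_cfg.build(engine.model)
```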
4 changes: 4 additions & 0 deletions xtuner/v1/optim/__init__.py
@@ -0,0 +1,4 @@
from .muon import Muon


__all__ = ["Muon"]