76 changes: 72 additions & 4 deletions aiak_megatron/megatron/core/optimizer/__init__.py
@@ -6,6 +6,7 @@
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
from .muon import Muon

try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
@@ -50,6 +51,8 @@ def _get_param_groups(
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
muon_matched_adamw_rms: Optional[float],
use_muon: bool = False,
) -> List[Dict]:
"""Create parameter groups for optimizer.

@@ -71,6 +74,8 @@
min_lr (float): minimum learning rate.
decoupled_lr (Optional[float]): optional decoupled learning rate.
decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
muon_matched_adamw_rms (Optional[float]): update RMS of the AdamW optimizer being matched, typically 0.2 ~ 0.4.
use_muon (bool): Whether to use Muon to create parameter groups.

Returns:
List of parameter groups.
@@ -80,6 +85,7 @@

# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
muon_params_map = {}
for model_chunk in model_chunks:
if model_chunk.ddp_config.use_custom_fsdp:
named_parameters = model_chunk.optimizer_named_parameters()
@@ -120,10 +126,22 @@
):
is_decoupled_lr = True

key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
# Check whether this is a 2-D linear weight eligible for Muon (not a bias or embedding).
bias_flag = name.endswith(".bias")
shape_flag = param.dim() == 2
embedding_flag = "embedding" in name or "output_layer" in name
muon_flag = use_muon and shape_flag \
and (not bias_flag) and (not embedding_flag)
if muon_flag:
key = (wd_mult, _lr_mult, is_expert_parallel)
if key not in muon_params_map:
muon_params_map[key] = []
muon_params_map[key].append(param)
else:
key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)

param_groups = []
for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
@@ -145,6 +163,21 @@
decoupled_min_lr=decoupled_min_lr,
)

for (wd_mult, _lr_mult, is_expert_parallel), params in muon_params_map.items():
if len(params) == 0:
continue
param_groups.append(
{
'params': params,
'wd_mult': wd_mult,
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'use_muon': True,
'is_decoupled_lr': False,
}
)


return param_groups


@@ -227,6 +260,8 @@ def _get_param_groups_and_buffers(
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
muon_matched_adamw_rms=config.muon_matched_adamw_rms,
use_muon=config.optimizer == 'muon',
)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
@@ -291,6 +326,19 @@ def _get_megatron_optimizer_based_on_param_groups(
bias_correction=True,
fused=True, # this flag is used to improve the performance of the cpu optimizer
)
elif config.optimizer == 'muon':
gpu_optimizer_cls = Muon
cpu_optimizer_cls = Muon
optimizer_defaults = dict(
lr=config.lr,
weight_decay=config.weight_decay,
matched_adamw_rms=config.muon_matched_adamw_rms,
momentum=config.muon_momentum,
nesterov=config.muon_nesterov,
ns_steps=config.muon_ns_steps,
adamw_betas=(config.adam_beta1, config.adam_beta2),
adamw_eps=config.adam_eps
)
else:
gpu_optimizer_cls = SGD
cpu_optimizer_cls = CPUSGD
@@ -362,6 +410,26 @@ def init_state_fn(opt, config=None):
momentum=config.sgd_momentum,
)
init_state_fn = None

elif config.optimizer == 'muon':
optimizer = Muon(param_groups,
lr=config.lr, weight_decay=config.weight_decay,
matched_adamw_rms=config.muon_matched_adamw_rms,
momentum=config.muon_momentum,
nesterov=config.muon_nesterov,
ns_steps=config.muon_ns_steps,
adamw_betas=(config.adam_beta1, config.adam_beta2),
adamw_eps=config.adam_eps)

def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
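
A minimal, self-contained sketch (not part of this diff) of the routing rule the updated _get_param_groups applies when use_muon is set: only 2-D weights that are neither biases nor embedding/output_layer parameters are handed to Muon, and everything else stays on the AdamW path. The helper name split_params_for_muon is hypothetical.

import torch.nn as nn

def split_params_for_muon(model: nn.Module):
    # Mirrors the muon_flag test above: 2-D, not a bias, not an embedding/output layer.
    muon_params, adamw_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_bias = name.endswith(".bias")
        is_embedding = "embedding" in name or "output_layer" in name
        if param.dim() == 2 and not is_bias and not is_embedding:
            muon_params.append(param)    # gets the orthogonalized Muon update
        else:
            adamw_params.append(param)   # handled by the AdamW-style branch
    return muon_params, adamw_params
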
@@ -3,6 +3,7 @@

from collections import defaultdict
from typing import Dict
from ..muon import Muon

import torch

@@ -173,7 +174,10 @@ def step(self, closure=None):
d2h_event = self._cpu_optimizer_map_data_event.pop(cpu_optimizer, None)
if d2h_event is not None:
d2h_event.synchronize()
cpu_optimizer.step(closure)
if isinstance(cpu_optimizer, Muon):
cpu_optimizer.step(self.cpu_copys_map_gpu_param)
else:
cpu_optimizer.step(closure)

# Sync state and param_groups to HDO after each step.
# NOTE: It is possible for the optimizer to change the properties
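
The hunk above changes how the CPU-side step is dispatched: when the wrapped optimizer is Muon, step receives the CPU-copy-to-GPU-parameter map rather than a closure. A hedged sketch of that dispatch follows; the exact Muon.step signature is an assumption inferred from this call site, and _step_cpu_optimizer is a hypothetical helper.

def _step_cpu_optimizer(cpu_optimizer, cpu_copys_map_gpu_param, closure=None):
    # Muon needs the mapping from CPU parameter copies back to GPU params at step time.
    if isinstance(cpu_optimizer, Muon):
        cpu_optimizer.step(cpu_copys_map_gpu_param)
    else:
        cpu_optimizer.step(closure)
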
94 changes: 87 additions & 7 deletions aiak_megatron/megatron/core/optimizer/distrib_optimizer.py
@@ -40,6 +40,8 @@
from .grad_scaler import MegatronGradScaler
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
from .optimizer_config import OptimizerConfig
from .muon import Muon, MuonDistMeta
from megatron.core.parallel_state import get_tensor_model_parallel_group

logger = getLogger(__name__)

@@ -136,6 +138,7 @@ def _build_model_gbuf_param_range_map(
sub_param_start = max(0, gbuf_world_range.start - param_world_start)
sub_param_range = param_local_range.normalize(sub_param_start)
param_range_map[param] = {
"world_indexes": (param_world_start, param_world_end),
"gbuf_world": param_world_range,
"gbuf_world_in_bucket": param_world_range_in_bucket,
"gbuf_local": param_local_range,
@@ -322,14 +325,26 @@ def _build_model_and_main_param_groups(
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)
shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)

dist_metas = {}

for model_param in group_range["params"]:

assert model_param.requires_grad

gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
param_gbuf_ranges = gbuf_range["param_map"][model_param]
param_range = param_gbuf_ranges["param"]

if config.optimizer == "muon":
# Generate the distributed metadata Muon needs for this parameter.
param_world_indexes = param_gbuf_ranges["world_indexes"]
# To get the tensor parallel split dim
tp_split_dim = getattr(model_param, 'partition_dim') if getattr(
model_param, 'tensor_model_parallel', False) else -1
dist_meta = MuonDistMeta(
gbuf_index, bucket_index, model_param.shape, param_world_indexes, tp_split_dim)

# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
@@ -391,6 +406,14 @@ def _build_model_and_main_param_groups(
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
shard_fp32_from_float16_params_this_group.append(shard_main_param)

# add to dist metas
if config.optimizer == "muon":
if config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
dist_metas[shard_model_param] = dist_meta
else:
dist_metas[shard_main_param] = dist_meta


# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
@@ -430,7 +453,7 @@
shard_float16_groups,
shard_fp32_groups,
shard_fp32_from_float16_groups,
)
), dist_metas

def __init__(
self,
@@ -487,8 +510,9 @@ def __init__(
assert self.ddp_config == model_chunk.ddp_config
self.distributed_optimizer_instance_id = distributed_optimizer_instance_id

assert isinstance(optimizer, (Adam, HybridDeviceOptimizer)) or optimizer is None, (
"Only Adam and HybridDeviceOptimizer currently supported, due to checkpointing requirements."
assert isinstance(optimizer, (Adam, HybridDeviceOptimizer, Muon)) or optimizer is None, (
"Only Adam / HybridDeviceOptimizer / Muon currently supported, "
"due to checkpointing requirements."
)

# when freezing sub-models we have no real optimizer
@@ -567,17 +591,40 @@ def __init__(
self.shard_float16_groups,
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
), dist_metas = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)

if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer = HybridDeviceOptimizer(
params=[g["orig_group"] for g in self.opt_group_ranges], **self.optimizer.defaults
params=[g["orig_group"] for g in self.opt_group_ranges],
**self.optimizer.defaults
)
if config.optimizer == 'muon':
assert all(grad_buffer.grad_dtype == torch.float32 for grad_buffer in self.buffers), \
"all grad buffers must contain only float32 data for the muon optimizer"
gbuf_sizes = [[(bucket.grad_data.numel(), bucket.offset) for bucket in buffer.buckets]
for buffer in self.buffers]
self.optimizer.cpu_optimizers[0].enable_distributed_mode(
gbuf_sizes, self.data_parallel_group,
get_tensor_model_parallel_group(),
dist_metas,
)
else:
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())

# for muon optimizer, enable distributed mode
if isinstance(self.optimizer, Muon):
assert all(grad_buffer.grad_dtype == torch.float32 for grad_buffer in self.buffers), \
"all grad buffers must contain only float32 data for the muon optimizer"
gbuf_sizes = [[(bucket.grad_data.numel(), bucket.offset) for bucket in buffer.buckets]
for buffer in self.buffers]
self.optimizer.enable_distributed_mode(
gbuf_sizes, self.data_parallel_group,
get_tensor_model_parallel_group(),
dist_metas,
)

def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
@@ -733,6 +780,15 @@ def load_state_dict(self, state_dict):
"exp_avg": init_shard(self.config.exp_avg_dtype),
"exp_avg_sq": init_shard(self.config.exp_avg_sq_dtype),
}
# For the Muon optimizer, remap the state keys (used when loading a state dict).
if self.config.optimizer == "muon":
use_muon = self.optimizer.param_groups[group_index].get("use_muon", False)
if use_muon:
tensors["muon_buffer"] = tensors.pop("exp_avg")
tensors.pop("exp_avg_sq", None)
else:
tensors["adamw_exp_avg"] = tensors.pop("exp_avg")
tensors["adamw_exp_avg_sq"] = tensors.pop("exp_avg_sq")
if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
if self.config.store_param_remainders and self.config.bf16:
tensors["master_param"] = init_shard(torch.int16)
@@ -845,6 +901,16 @@ def _get_main_param_and_optimizer_states(self, model_param):
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}

# Remap Muon state so it stays compatible with Adam checkpoints (always saved as exp_avg / exp_avg_sq).
if isinstance(self.optimizer, Muon):
use_muon = self.optimizer.param_groups[group_index].get("use_muon", False)
if use_muon:
tensors["exp_avg"] = tensors["muon_buffer"]
tensors["exp_avg_sq"] = torch.zeros_like(tensors["param"])
else:
tensors["exp_avg"] = tensors["adamw_exp_avg"]
tensors["exp_avg_sq"] = tensors["adamw_exp_avg_sq"]
return tensors

def _set_main_param_and_optimizer_states(self, model_param, tensors):
Expand Down Expand Up @@ -874,6 +940,17 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors):
return

group_index, group_order = self.model_param_group_index_map[model_param]

# For the Muon optimizer, remap the state keys (used when loading a state dict).
if self.config.optimizer == "muon":
use_muon = self.optimizer.param_groups[group_index].get("use_muon", False)
if use_muon:
tensors["muon_buffer"] = tensors.pop("exp_avg")
tensors.pop("exp_avg_sq", None)
else:
tensors["adamw_exp_avg"] = tensors.pop("exp_avg")
tensors["adamw_exp_avg_sq"] = tensors.pop("exp_avg_sq")

if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
for k, v in tensors.items():
@@ -892,6 +969,9 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors):
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
# Skip keys the Muon optimizer does not track (exp_avg / exp_avg_sq were popped above).
if self.config.optimizer == "muon" and key not in tensors:
continue
dst_tensors[key].copy_(tensors[key])

def get_parameter_state_fs_bucket_space(self):
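
The state-dict paths above keep checkpoints in the Adam-style key layout and only translate to Muon's internal names in memory. A short sketch of that two-way mapping, distilled from the hunks above; the function names are illustrative and not part of the diff.

import torch

def adam_keys_to_muon_keys(tensors, use_muon):
    # Load path: rename checkpoint keys to the names Muon uses internally.
    if use_muon:
        tensors["muon_buffer"] = tensors.pop("exp_avg")
        tensors.pop("exp_avg_sq", None)  # Muon keeps no second-moment state
    else:
        tensors["adamw_exp_avg"] = tensors.pop("exp_avg")
        tensors["adamw_exp_avg_sq"] = tensors.pop("exp_avg_sq")
    return tensors

def muon_keys_to_adam_keys(tensors, use_muon):
    # Save path: always expose exp_avg / exp_avg_sq so Muon checkpoints stay Adam-compatible.
    if use_muon:
        tensors["exp_avg"] = tensors["muon_buffer"]
        tensors["exp_avg_sq"] = torch.zeros_like(tensors["param"])
    else:
        tensors["exp_avg"] = tensors["adamw_exp_avg"]
        tensors["exp_avg_sq"] = tensors["adamw_exp_avg_sq"]
    return tensors
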