
Commit 71bb0fd

FDecaYed, yanring, skyw, ko3n1g, and BoxiangW authored and committed
Add muon and layerwise distributed optimizer (NVIDIA#2241)
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
Signed-off-by: Deyu Fu <deyuf@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Co-authored-by: Zijie Yan <zijiey@nvidia.com>
Co-authored-by: Hao Wu <skyw@nvidia.com>
Co-authored-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: Boxiang Wang <boxiangw@nvidia.com>
Co-authored-by: mikail <mkhona@nvidia.com>
Co-authored-by: Philip Petrakian <ppetrakian@nvidia.com>
1 parent 8faf282 commit 71bb0fd

17 files changed: +2729 −248 lines changed

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import logging
from typing import Callable, List, Optional

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedStateDict
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.utils import get_pg_rank, get_pg_size

from .clip_grads import count_zeros_fp32, get_grad_norm_fp32
from .optimizer import (
    ChainedOptimizer,
    Float16OptimizerWithFloat16Params,
    FP32Optimizer,
    MegatronOptimizer,
)
from .optimizer_config import OptimizerConfig

logger = logging.getLogger(__name__)

class LayerWiseDistributedOptimizer(ChainedOptimizer):
    """Layer-wise distributed optimizer for Megatron-core models.

    Experimental distributed optimizer wrapper that distributes weights to DP ranks by layer.
    Implemented as a ChainedOptimizer to support multiple optimizers (e.g. Muon + AdamW).
    When using it, keep all Megatron distributed-optimizer related options OFF.

    How LayerWiseDistributedOptimizer works:
    1. Weights are split into lists and each rank keeps only its shard in its optimizer.
    2. Megatron DDP handles the grad allreduce; note that each rank holds the full model and grads.
    3. The optimizers are modified so that only params belonging to this DP rank are updated.
    4. grad_norm and zero counting reduce their metrics globally in the step function.
    5. A regular update runs through the chained optimizers; each modified optimizer updates only its shard.
    6. Updated params are allgathered to every rank.
    """

    def __init__(
        self,
        optimizers: List[MegatronOptimizer],
        config: OptimizerConfig,
        pg_collection: Optional[ProcessGroupCollection] = None,
        init_state_fn_list: Optional[List[Callable]] = None,
    ) -> None:
        """
        Initialize LayerWiseDistributedOptimizer.

        Args:
            optimizers: List of MegatronOptimizers.
            config: OptimizerConfig.
            pg_collection: ProcessGroupCollection.
            init_state_fn_list: List of init state functions.
        """
        self.pg_collection = pg_collection
        self.shard_params(optimizers)
        if init_state_fn_list:
            assert len(init_state_fn_list) == len(
                optimizers
            ), "init_state_fn_list must be the same length as optimizers if provided"

        # Wrap optimizers after sharding to avoid unnecessary master weight creation.
        # For higher precision, optimizers are already wrapped by Megatron.
        if config.bf16:
            # Unwrap FP32 optimizers, possibly from reusing get_megatron_optimizer for adam.
            for i in range(len(optimizers)):
                opt = optimizers[i]
                if isinstance(opt, Float16OptimizerWithFloat16Params):
                    raise TypeError(
                        'LayerWiseDistributedOptimizer received Float16 optimizer already.'
                    )
                # Unwrap FP32 optimizer from reusing get_megatron_optimizer for adam.
                if isinstance(opt, FP32Optimizer):
                    opt = opt.optimizer
                optimizers[i] = Float16OptimizerWithFloat16Params(
                    opt, config, None, init_state_fn_list[i] if init_state_fn_list else None
                )

        super().__init__(optimizers)

        # TODO(kunlun, deyuf): potential future perf optimization.
        # Since the allreduce is unchanged and handled by Megatron DDP, grads are already in a
        # contiguous grad buffer. So instead of sharding params by layer, we can shard by
        # buffer range but keep some "extras" so boundary weights are not sharded.
        # This way each rank does some duplicated work but allgather_v is no longer needed.
        # All current distopt optimizations can also potentially be applied.

    def shard_params(self, optimizers):
        """Shard all params into lists by rank."""
        # Parameters are sorted by numel and assigned to ranks in ping-pong style.
        # Example: with 4 ranks and 10 parameters p0-p9 after sorting, dp_cp_params_list will be
        # [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]]

        # simplify when dp_cp group size is 1
        if get_pg_size(self.pg_collection.dp_cp) == 1:
            self.dp_cp_params_list = None
            self.expt_dp_params_list = None
            return

        dp_cp_idx, expt_dp_idx = 0, 0
        dp_cp_size = get_pg_size(self.pg_collection.dp_cp)
        expt_dp_size = get_pg_size(self.pg_collection.expt_dp)
        # create ping-pong style loop so memory is more balanced
        dp_cp_loop = list(range(dp_cp_size)) + list(range(dp_cp_size))[::-1]
        expt_dp_loop = list(range(expt_dp_size)) + list(range(expt_dp_size))[::-1]
        self.dp_cp_params_list = [[] for _ in range(dp_cp_size)]
        self.expt_dp_params_list = [[] for _ in range(expt_dp_size)]
        # get all param groups
        param_groups = []
        for optimizer in optimizers:
            param_groups += optimizer.param_groups

        # sort params in all groups by param numel and assign to each rank evenly
        param_list = []
        for group_index, group in enumerate(param_groups):
            for p in group["params"]:
                param_list.append((p, group_index))
        param_list.sort(key=lambda x: x[0].numel())
        param_groups_this_rank = [[] for g in param_groups]

        # assign params to ranks in ping-pong style loop
        for p, group_index in param_list:
            if param_groups[group_index].get("is_expert_parallel", False):
                if expt_dp_loop[expt_dp_idx] == get_pg_rank(self.pg_collection.expt_dp):
                    param_groups_this_rank[group_index].append(p)
                self.expt_dp_params_list[expt_dp_loop[expt_dp_idx]].append(p)
                expt_dp_idx = (expt_dp_idx + 1) % len(expt_dp_loop)
            else:
                if dp_cp_loop[dp_cp_idx] == get_pg_rank(self.pg_collection.dp_cp):
                    param_groups_this_rank[group_index].append(p)
                self.dp_cp_params_list[dp_cp_loop[dp_cp_idx]].append(p)
                dp_cp_idx = (dp_cp_idx + 1) % len(dp_cp_loop)

        # now we modify the groups to only handle local params
        for groups, params in zip(param_groups, param_groups_this_rank):
            groups["params"] = params

        # simplify when expt_dp group size is 1 or expert parallel is off
        if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0:
            self.expt_dp_params_list = None

    @torch.no_grad()
    def allgather_params(self) -> None:
        """All-gather updated params from all ranks."""

        # helper function to flatten local params, allgather, unflatten and copy to model params
        def _allgather_helper(params_list, group):
            # flatten this rank's params and create empty tensor output list
            device = params_list[0][0].device
            dtype = params_list[0][0].dtype
            rank = get_pg_rank(group)
            # for ranks without params, create an empty tensor and participate in the allgather
            src = (
                _flatten_dense_tensors(params_list[rank])
                if len(params_list[rank]) > 0
                else torch.empty(0, device=device, dtype=dtype)
            )
            output_list = [
                torch.empty(sum([p.numel() for p in params]), device=device, dtype=dtype)
                for params in params_list
            ]
            # single all_gather_v to collect all updated params
            torch.distributed.all_gather(output_list, src, group=group)
            # unflatten and copy gathered params for each rank
            for idx, (flat_params, params) in enumerate(zip(output_list, params_list)):
                # skip local params and empty tensors
                if len(params) == 0 or idx == rank:
                    continue
                updated_params = _unflatten_dense_tensors(flat_params, params)
                for updated_p, model_p in zip(updated_params, params):
                    model_p.data.copy_(updated_p)

        if self.pg_collection is None:
            return
        if self.dp_cp_params_list:
            _allgather_helper(self.dp_cp_params_list, self.pg_collection.dp_cp)
        if self.expt_dp_params_list:
            _allgather_helper(self.expt_dp_params_list, self.pg_collection.expt_dp)

    @torch.no_grad()
    def broadcast_params(self):
        """Each rank broadcasts its updated local params."""
        # Broadcast linear layer weights to all other ranks. Kept as reference test.
        if self.dp_cp_params_list is None:
            return
        for i, params in enumerate(self.dp_cp_params_list):
            src_global_rank = torch.distributed.get_global_rank(self.pg_collection.dp_cp, i)
            for p in params:
                torch.distributed.broadcast(p, src_global_rank, self.pg_collection.dp_cp)
        if self.expt_dp_params_list is None:
            return
        for i, params in enumerate(self.expt_dp_params_list):
            src_global_rank = torch.distributed.get_global_rank(self.pg_collection.expt_dp, i)
            for p in params:
                torch.distributed.broadcast(p, src_global_rank, self.pg_collection.expt_dp)

    @torch.no_grad()
    def get_grad_norm(self):
        # similar to dist opt, always aggregate globally
        grads_for_norm = []
        for optimizer in self.chained_optimizers:
            grads_for_norm += optimizer.get_main_grads_for_grad_norm()
        grad_norm = get_grad_norm_fp32(grads_for_norm, grad_stats_parallel_group=None)
        return grad_norm

    @torch.no_grad()
    def count_zeros(self):
        params = []
        for optimizer in self.chained_optimizers:
            params += optimizer.get_parameters()
        return count_zeros_fp32(
            params,
            grad_stats_parallel_group=None,
            use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8,
        )

    @torch.no_grad()
    def step(self):  # type: ignore[no-untyped-def]
        """step function for layer-wise optimizer."""
        update_successful, grad_norm, num_zeros_in_grad = super().step()

        # All gather updated params.
        self.allgather_params()

        return update_successful, grad_norm, num_zeros_in_grad

    # TODO(deyuf): need to improve dist checkpointing design to properly handle this.
    # fp32_from_fp16_params is a list; each sub-list could be empty if the group is empty.
    # This breaks a dist checkpointing assumption since extract_sharded_base drops list structure.
    # For now, we convert it to a dict with the index as key and convert back in load_state_dict.
    def load_state_dict(self, state_dict):
        if len(self.chained_optimizers) == 1:
            wrapped_state_dict = {1: state_dict}
        else:
            wrapped_state_dict = state_dict
        for sd in wrapped_state_dict.values():
            if 'fp32_from_fp16_params' in sd and isinstance(sd['fp32_from_fp16_params'], dict):
                logger.info('[layerwise] converting fp32_from_fp16_params from dict to list')
                sd['fp32_from_fp16_params'] = [
                    v for k, v in sorted(sd['fp32_from_fp16_params'].items())
                ]
        super().load_state_dict(state_dict)

    def sharded_state_dict(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs
    ):
        """
        Sharded state dict for torch_dist format checkpointing.
        For fixed DP usage only; set replica_id to 0 for all ShardedTensors.
        """
        sharded_state_dict = super().sharded_state_dict(
            model_sharded_state_dict, is_loading, **kwargs
        )

        # for fixed DP usage only
        for sh_base in nested_values(sharded_state_dict):
            if hasattr(sh_base, 'replica_id'):
                assert (
                    isinstance(sh_base.replica_id, int) or len(sh_base.replica_id) == 3
                ), f'Expected replica_id as int or (PP, TP, DP), got: {sh_base}'
                sh_base.replica_id = (
                    0 if isinstance(sh_base.replica_id, int) else (*sh_base.replica_id[:2], 0)
                )

        # Later code assumes a list, but ChainedOptimizer falls back to non-list if there is only one.
        if len(self.chained_optimizers) == 1:
            wrapped_sharded_state_dict = {1: sharded_state_dict}
        else:
            wrapped_sharded_state_dict = sharded_state_dict

        # Adjust the dict so rank 0 outputs the correct global metadata into common_dict.
        for sd in wrapped_sharded_state_dict.values():
            # Wrap empty containers into LocalNonpersistentObject so they won't be saved/loaded.
            # params is already wrapped; we only need to handle fp32_from_fp16_params and state.
            # More details in the load_state_dict comment.
            if 'fp32_from_fp16_params' in sd:
                sd['fp32_from_fp16_params'][:] = [
                    group if group else LocalNonpersistentObject(group)
                    for group in sd['fp32_from_fp16_params']
                ]
                sd['fp32_from_fp16_params'] = {
                    i: v for i, v in enumerate(sd['fp32_from_fp16_params'])
                }
            # state is a single dict and will be empty if the optimizer is fully empty
            if not sd['optimizer']['state']:
                sd['optimizer']['state'] = LocalNonpersistentObject(sd['optimizer']['state'])
            # group keys (e.g. 'step') might be missing or not updated
            for i, group in enumerate(sd['optimizer']['param_groups']):
                # keep local param tensor so we only gather metadata
                local_params = group.pop('params')
                # save whether this group is empty, so we can use a non-empty rank for metadata
                group['params'] = bool(local_params.unwrap())
                all_rank_groups = [None for _ in range(torch.distributed.get_world_size())]
                torch.distributed.all_gather_object(all_rank_groups, group)
                # find the first non-empty group if it exists
                nonempty_rank_group = next((g for g in all_rank_groups if g['params']), group)
                nonempty_rank_group['params'] = local_params
                sd['optimizer']['param_groups'][i] = nonempty_rank_group
        return sharded_state_dict

    def save_state_dict_to_file(self, filename: str) -> None:
        """Save the parameter state of the optimizer. For torch format only.
        Args:
            filename: The filename to save the parameter state.
        """
        torch.save(super().state_dict(), filename)

    def load_state_dict_from_file(self, filename: str) -> None:
        """Load the parameter state of the optimizer. For torch format only."""
        super().load_state_dict(torch.load(filename))
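The ping-pong assignment used by shard_params can be illustrated in isolation. The standalone sketch below is not part of the commit; ping_pong_assign and its arguments are hypothetical names chosen for illustration. Under those assumptions it reproduces the [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]] example from the docstring for 4 ranks and 10 numel-sorted parameters.

# Standalone sketch (illustrative only) of the ping-pong sharding in shard_params.
def ping_pong_assign(num_params: int, world_size: int):
    # Forward-then-backward rank loop, e.g. [0, 1, 2, 3, 3, 2, 1, 0] for 4 ranks.
    loop = list(range(world_size)) + list(range(world_size))[::-1]
    shards = [[] for _ in range(world_size)]
    for idx in range(num_params):
        shards[loop[idx % len(loop)]].append(f"p{idx}")
    return shards

# With 4 ranks and 10 already-sorted params this prints
# [['p0', 'p7', 'p8'], ['p1', 'p6', 'p9'], ['p2', 'p5'], ['p3', 'p4']].
print(ping_pong_assign(10, 4))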
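allgather_params relies on flattening each rank's shard into one contiguous buffer before a single collective call, then unflattening on the receiving side. Below is a minimal local sketch of that flatten/unflatten round trip using the same torch._utils helpers; the tensor shapes are made up and no process group is involved, so it only mirrors what _allgather_helper does for a remote rank's buffer.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Pretend these are one rank's updated param shard (arbitrary shapes for illustration).
params = [torch.randn(4, 3), torch.randn(5), torch.randn(2, 2)]

# Sender side: a single flat buffer per rank is what goes through all_gather.
flat = _flatten_dense_tensors(params)
assert flat.numel() == sum(p.numel() for p in params)

# Receiver side: split the gathered flat buffer back into the original shapes
# and copy into the local model params, as _allgather_helper does for remote ranks.
for updated_p, model_p in zip(_unflatten_dense_tensors(flat, params), params):
    model_p.data.copy_(updated_p)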
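The checkpointing workaround described in the load_state_dict comment converts the fp32_from_fp16_params list to an index-keyed dict in sharded_state_dict and back to a list on load. A plain-Python sketch of that round trip, with hypothetical values and none of the checkpointing machinery:

# Per-group lists; some sub-lists may be empty on a given rank (values are placeholders).
groups = [['w0', 'w1'], [], ['w2']]

# sharded_state_dict side: index-keyed dict so empty sub-lists survive extract_sharded_base.
as_dict = {i: v for i, v in enumerate(groups)}

# load_state_dict side: sort by index to restore the original list order.
as_list = [v for k, v in sorted(as_dict.items())]
assert as_list == groups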
