Skip to content

Commit dfe4da2

Browse files
authored
Update tp support in muon (#2385)
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent aee4a74 commit dfe4da2

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

megatron/core/optimizer/muon.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch
99
from torch.optim.optimizer import ParamsT
1010

11-
from megatron.core import parallel_state
1211
from megatron.core.process_groups_config import ProcessGroupCollection
1312
from megatron.core.transformer.module import MegatronModule
1413
from megatron.core.utils import get_pg_size, log_single_rank
@@ -76,7 +75,7 @@ def scaled_orthogonalize_fn(
7675
f'{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}',
7776
)
7877
size = [grad.size(-2), grad.size(-1)]
79-
if partition_dim:
78+
if partition_dim is not None:
8079
size[partition_dim] *= get_pg_size(tp_group)
8180
orth_grad = newton_schulz_tp(
8281
grad,
@@ -130,8 +129,7 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t
130129
tp_group = None
131130
partition_dim = None if self.mode == "blockwise" else getattr(p, "partition_dim", None)
132131
if partition_dim == -1:
133-
# llm-shower uses a different default value for partition_dim than TE.
134-
# Because -1 is a valid index for ndarray, we decided not to overload it.
132+
# emerging-optimizers use None instead of -1 to indicate no tensor parallelism
135133
partition_dim = None
136134

137135
if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc]
@@ -201,8 +199,6 @@ def get_megatron_muon_optimizer(
201199
# before this function receive properly created collection
202200
if pg_collection is None:
203201
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
204-
pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True)
205-
pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group()
206202

207203
log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')
208204

0 commit comments

Comments
 (0)