|
8 | 8 | import torch |
9 | 9 | from torch.optim.optimizer import ParamsT |
10 | 10 |
|
11 | | -from megatron.core import parallel_state |
12 | 11 | from megatron.core.process_groups_config import ProcessGroupCollection |
13 | 12 | from megatron.core.transformer.module import MegatronModule |
14 | 13 | from megatron.core.utils import get_pg_size, log_single_rank |
@@ -76,7 +75,7 @@ def scaled_orthogonalize_fn( |
76 | 75 | f'{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}', |
77 | 76 | ) |
78 | 77 | size = [grad.size(-2), grad.size(-1)] |
79 | | - if partition_dim: |
| 78 | + if partition_dim is not None: |
80 | 79 | size[partition_dim] *= get_pg_size(tp_group) |
81 | 80 | orth_grad = newton_schulz_tp( |
82 | 81 | grad, |
@@ -130,8 +129,7 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t |
130 | 129 | tp_group = None |
131 | 130 | partition_dim = None if self.mode == "blockwise" else getattr(p, "partition_dim", None) |
132 | 131 | if partition_dim == -1: |
133 | | - # llm-shower use different default value for partition_dim than TE. |
134 | | - # Because -1 is a valid index for ndarray, we decided to not overload it. |
 | 132 | + # emerging-optimizers uses None instead of -1 to indicate no tensor parallelism
135 | 133 | partition_dim = None |
136 | 134 |
|
137 | 135 | if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc] |
@@ -201,8 +199,6 @@ def get_megatron_muon_optimizer( |
201 | 199 | # before this function receives a properly created collection
202 | 200 | if pg_collection is None: |
203 | 201 | pg_collection = ProcessGroupCollection.use_mpu_process_groups() |
204 | | - pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) |
205 | | - pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group() |
206 | 202 |
|
207 | 203 | log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}') |
208 | 204 |
|
|
0 commit comments