Commit 504b5f4

Add DeepEP fallback logic and tests
Signed-off-by: ooooo <[email protected]>
Parent commit: 1d42deb

2 files changed: 24 additions & 5 deletions

nemo_automodel/components/moe/layers.py

Lines changed: 11 additions & 2 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import warnings
 from dataclasses import dataclass
 from functools import partial
 from typing import Literal, Optional
@@ -20,6 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from nemo_automodel.components.distributed.init_utils import get_world_size_safe
 from nemo_automodel.components.moe.utils import BackendConfig, initialize_linear_module
 from nemo_automodel.shared.utils import dtype_from_str
 
@@ -914,7 +915,15 @@ def __init__(self, config: MoEConfig, backend: BackendConfig):
             self.gate = FakeBalancedGate(config)
         else:
             self.gate = Gate(config, gate_precision=backend.gate_precision)
-        if backend.enable_deepep:
+        if backend.enable_deepep and get_world_size_safe() == 1:
+            warnings.warn(
+                "DeepEP is enabled in config, but world size is 1. "
+                "DeepEP requires multiple GPUs. Falling back to standard GroupedExperts.",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            self.experts = GroupedExperts(config)
+        elif backend.enable_deepep:
             self.experts = GroupedExpertsDeepEP(config)
         else:
             self.experts = GroupedExperts(config)
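For reference, a minimal standalone sketch of the selection logic this hunk introduces. The get_world_size_safe stand-in below is an assumption (the real helper lives in nemo_automodel.components.distributed.init_utils and may behave differently), and select_experts, grouped_experts_cls, and deepep_cls are hypothetical names used only for illustration.

import warnings

import torch.distributed as dist


def get_world_size_safe() -> int:
    # Assumed behavior: report the distributed world size when a process
    # group is initialized, otherwise treat the run as single-process.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def select_experts(config, backend, grouped_experts_cls, deepep_cls):
    # Mirrors the commit's branch: DeepEP needs more than one rank, so a
    # single-process run warns and falls back to the standard experts.
    if backend.enable_deepep and get_world_size_safe() == 1:
        warnings.warn(
            "DeepEP is enabled in config, but world size is 1. "
            "DeepEP requires multiple GPUs. Falling back to standard GroupedExperts.",
            category=UserWarning,
            stacklevel=2,
        )
        return grouped_experts_cls(config)
    if backend.enable_deepep:
        return deepep_cls(config)
    return grouped_experts_cls(config)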

tests/unit_tests/moe/test_layers.py

Lines changed: 13 additions & 3 deletions

@@ -1258,10 +1258,20 @@ def test_moe_init_with_fake_balanced_gate(self, moe_config, backend_config):
         assert isinstance(moe.gate, FakeBalancedGate)
         assert isinstance(moe.experts, GroupedExperts)
 
-    def test_moe_init_with_deepep(self, moe_config, backend_config):
-        """Test MoE initialization with DeepEP."""
+    def test_moe_init_with_deepep_single_device(self, moe_config, backend_config):
+        """DeepEP enabled but world size == 1 should fall back to GroupedExperts."""
         backend_config.enable_deepep = True
-        moe = MoE(moe_config, backend_config)
+        with patch("nemo_automodel.components.moe.layers.get_world_size_safe", return_value=1):
+            moe = MoE(moe_config, backend_config)
+
+        assert isinstance(moe.gate, Gate)
+        assert isinstance(moe.experts, GroupedExperts)
+
+    def test_moe_init_with_deepep_multi_device(self, moe_config, backend_config):
+        """DeepEP enabled and world size > 1 should use GroupedExpertsDeepEP."""
+        backend_config.enable_deepep = True
+        with patch("nemo_automodel.components.moe.layers.get_world_size_safe", return_value=2):
+            moe = MoE(moe_config, backend_config)
 
         assert isinstance(moe.gate, Gate)
         assert isinstance(moe.experts, GroupedExpertsDeepEP)
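A note on the mocking pattern: patch must target get_world_size_safe where it is looked up (the layers module), not where it is defined. The toy module and names below are illustrative only, not part of nemo_automodel.

import types
from unittest.mock import patch

# Toy stand-in for a module that imported a world-size helper at import time.
layers = types.ModuleType("toy_layers")
layers.get_world_size_safe = lambda: 1


def build_experts():
    # Stand-in for the MoE constructor's branch on world size.
    return "deepep" if layers.get_world_size_safe() > 1 else "grouped"


# Patching the attribute on the module that uses the helper flips the branch.
with patch.object(layers, "get_world_size_safe", return_value=2):
    assert build_experts() == "deepep"

assert build_experts() == "grouped"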
