We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c959c6b commit 548b710Copy full SHA for 548b710
plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -14,6 +14,7 @@
14
15
# Standard
16
from typing import Dict, Tuple
17
+import os
18
19
# Third Party
20
from fms_acceleration import AccelerationPlugin
@@ -69,7 +70,8 @@ def augmentation(
69
70
rank, world_size = 0, 1
71
if torch.distributed.is_initialized():
72
world_size = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
73
+ # we do not need to use the fallback as this is wrapped in an `is_initialized` block
74
+ rank = torch.distributed.get_node_local_rank()
75
76
if not hasattr(model.config, "name_or_path") or not model.config.name_or_path:
77
raise ValueError(
0 commit comments