Skip to content

Commit 1bb2f8c

Browse files
committed
fix: compute device correctly
Signed-off-by: Mehant Kammakomati <[email protected]> Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent c959c6b commit 1bb2f8c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# Standard
1616
from typing import Dict, Tuple
17+
import os
1718

1819
# Third Party
1920
from fms_acceleration import AccelerationPlugin
@@ -69,7 +70,8 @@ def augmentation(
6970
rank, world_size = 0, 1
7071
if torch.distributed.is_initialized():
7172
world_size = torch.distributed.get_world_size()
72-
rank = torch.distributed.get_rank()
73+
# we do not need to use the fallback as this is wrapped in an `is_initialized` block
74+
rank = torch.distributed.get_node_local_rank()
7375

7476
if not hasattr(model.config, "name_or_path") or not model.config.name_or_path:
7577
raise ValueError(

0 commit comments

Comments
 (0)