Skip to content

Commit ed7821d

Browse files
committed
fix: compute device correctly
Signed-off-by: Mehant Kammakomati <[email protected]> Signed-off-by: Yu Chin Fabian Lim <[email protected]> Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent c959c6b commit ed7821d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def augmentation(
6969
rank, world_size = 0, 1
7070
if torch.distributed.is_initialized():
7171
world_size = torch.distributed.get_world_size()
72-
rank = torch.distributed.get_rank()
72+
# we do not need to use the fallback as this is wrapped in an `is_initialized` block
73+
rank = torch.distributed.get_node_local_rank()
7374

7475
if not hasattr(model.config, "name_or_path") or not model.config.name_or_path:
7576
raise ValueError(

0 commit comments

Comments
 (0)