Skip to content

Commit 9bdedd0

Browse files
author
AlexandrByzov
committed
bugfix: add support for global_ordinal, local_ordinal, world_size in xla
1 parent 6675932 commit 9bdedd0

File tree

1 file changed

+15
-0
lines changed
  • src/lightning/fabric/plugins/environments

1 file changed

+15
-0
lines changed

src/lightning/fabric/plugins/environments/xla.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def world_size(self) -> int:
6666
The output is cached for performance.
6767
6868
"""
69+
if _XLA_GREATER_EQUAL_2_1:
70+
from torch_xla import runtime as xr
71+
72+
return xr.world_size()
73+
6974
import torch_xla.core.xla_model as xm
7075

7176
return xm.xrt_world_size()
@@ -82,6 +87,11 @@ def global_rank(self) -> int:
8287
The output is cached for performance.
8388
8489
"""
90+
if _XLA_GREATER_EQUAL_2_1:
91+
from torch_xla import runtime as xr
92+
93+
return xr.global_ordinal()
94+
8595
import torch_xla.core.xla_model as xm
8696

8797
return xm.get_ordinal()
@@ -98,6 +108,11 @@ def local_rank(self) -> int:
98108
The output is cached for performance.
99109
100110
"""
111+
if _XLA_GREATER_EQUAL_2_1:
112+
from torch_xla import runtime as xr
113+
114+
return xr.local_ordinal()
115+
101116
import torch_xla.core.xla_model as xm
102117

103118
return xm.get_local_ordinal()

0 commit comments

Comments
 (0)