Skip to content

Commit 520c1e4

Browse files
authored
Add @override for files in src/lightning/fabric/plugins/environments (#19098)
1 parent c5363af commit 520c1e4

File tree

7 files changed

+85
-0
lines changed

7 files changed

+85
-0
lines changed

src/lightning/fabric/plugins/environments/kubeflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import logging
1616
import os
1717

18+
from typing_extensions import override
19+
1820
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
1921

2022
log = logging.getLogger(__name__)
@@ -32,35 +34,45 @@ class KubeflowEnvironment(ClusterEnvironment):
3234
"""
3335

3436
@property
37+
@override
3538
def creates_processes_externally(self) -> bool:
3639
return True
3740

3841
@property
42+
@override
3943
def main_address(self) -> str:
4044
return os.environ["MASTER_ADDR"]
4145

4246
@property
47+
@override
4348
def main_port(self) -> int:
4449
return int(os.environ["MASTER_PORT"])
4550

4651
@staticmethod
52+
@override
4753
def detect() -> bool:
4854
raise NotImplementedError("The Kubeflow environment can't be detected automatically.")
4955

56+
@override
5057
def world_size(self) -> int:
5158
return int(os.environ["WORLD_SIZE"])
5259

60+
@override
5361
def set_world_size(self, size: int) -> None:
5462
log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
5563

64+
@override
5665
def global_rank(self) -> int:
5766
return int(os.environ["RANK"])
5867

68+
@override
5969
def set_global_rank(self, rank: int) -> None:
6070
log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
6171

72+
@override
6273
def local_rank(self) -> int:
6374
return 0
6475

76+
@override
6577
def node_rank(self) -> int:
6678
return self.global_rank()

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import os
1616
import socket
1717

18+
from typing_extensions import override
19+
1820
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
1921
from lightning.fabric.utilities.rank_zero import rank_zero_only
2022

@@ -42,6 +44,7 @@ def __init__(self) -> None:
4244
self._world_size: int = 1
4345

4446
@property
47+
@override
4548
def creates_processes_externally(self) -> bool:
4649
"""Returns whether the cluster creates the processes or not.
4750
@@ -52,10 +55,12 @@ def creates_processes_externally(self) -> bool:
5255
return "LOCAL_RANK" in os.environ
5356

5457
@property
58+
@override
5559
def main_address(self) -> str:
5660
return os.environ.get("MASTER_ADDR", "127.0.0.1")
5761

5862
@property
63+
@override
5964
def main_port(self) -> int:
6065
if self._main_port == -1:
6166
self._main_port = (
@@ -64,29 +69,37 @@ def main_port(self) -> int:
6469
return self._main_port
6570

6671
@staticmethod
72+
@override
6773
def detect() -> bool:
6874
return True
6975

76+
@override
7077
def world_size(self) -> int:
7178
return self._world_size
7279

80+
@override
7381
def set_world_size(self, size: int) -> None:
7482
self._world_size = size
7583

84+
@override
7685
def global_rank(self) -> int:
7786
return self._global_rank
7887

88+
@override
7989
def set_global_rank(self, rank: int) -> None:
8090
self._global_rank = rank
8191
rank_zero_only.rank = rank
8292

93+
@override
8394
def local_rank(self) -> int:
8495
return int(os.environ.get("LOCAL_RANK", 0))
8596

97+
@override
8698
def node_rank(self) -> int:
8799
group_rank = os.environ.get("GROUP_RANK", 0)
88100
return int(os.environ.get("NODE_RANK", group_rank))
89101

102+
@override
90103
def teardown(self) -> None:
91104
if "WORLD_SIZE" in os.environ:
92105
del os.environ["WORLD_SIZE"]

src/lightning/fabric/plugins/environments/lsf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import socket
1717
from typing import Dict, List
1818

19+
from typing_extensions import override
20+
1921
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
2022
from lightning.fabric.utilities.cloud_io import get_filesystem
2123

@@ -62,27 +64,32 @@ def _set_init_progress_group_env_vars(self) -> None:
6264
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
6365

6466
@property
67+
@override
6568
def creates_processes_externally(self) -> bool:
6669
"""LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
6770
return True
6871

6972
@property
73+
@override
7074
def main_address(self) -> str:
7175
"""The main address is read from an OpenMPI host rank file in the environment variable
7276
``LSB_DJOB_RANKFILE``."""
7377
return self._main_address
7478

7579
@property
80+
@override
7681
def main_port(self) -> int:
7782
"""The main port is calculated from the LSF job ID."""
7883
return self._main_port
7984

8085
@staticmethod
86+
@override
8187
def detect() -> bool:
8288
"""Returns ``True`` if the current process was launched using the ``jsrun`` command."""
8389
required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
8490
return required_env_vars.issubset(os.environ.keys())
8591

92+
@override
8693
def world_size(self) -> int:
8794
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
8895
world_size = os.environ.get("JSM_NAMESPACE_SIZE")
@@ -93,9 +100,11 @@ def world_size(self) -> int:
93100
)
94101
return int(world_size)
95102

103+
@override
96104
def set_world_size(self, size: int) -> None:
97105
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
98106

107+
@override
99108
def global_rank(self) -> int:
100109
"""The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
101110
global_rank = os.environ.get("JSM_NAMESPACE_RANK")
@@ -106,9 +115,11 @@ def global_rank(self) -> int:
106115
)
107116
return int(global_rank)
108117

118+
@override
109119
def set_global_rank(self, rank: int) -> None:
110120
log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
111121

122+
@override
112123
def local_rank(self) -> int:
113124
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
114125
local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
@@ -119,6 +130,7 @@ def local_rank(self) -> int:
119130
)
120131
return int(local_rank)
121132

133+
@override
122134
def node_rank(self) -> int:
123135
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored in
124136
``LSB_DJOB_RANKFILE``."""

src/lightning/fabric/plugins/environments/mpi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Optional
1919

2020
from lightning_utilities.core.imports import RequirementCache
21+
from typing_extensions import override
2122

2223
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
2324
from lightning.fabric.plugins.environments.lightning import find_free_network_port
@@ -47,22 +48,26 @@ def __init__(self) -> None:
4748
self._main_port: Optional[int] = None
4849

4950
@property
51+
@override
5052
def creates_processes_externally(self) -> bool:
5153
return True
5254

5355
@property
56+
@override
5457
def main_address(self) -> str:
5558
if self._main_address is None:
5659
self._main_address = self._get_main_address()
5760
return self._main_address
5861

5962
@property
63+
@override
6064
def main_port(self) -> int:
6165
if self._main_port is None:
6266
self._main_port = self._get_main_port()
6367
return self._main_port
6468

6569
@staticmethod
70+
@override
6671
def detect() -> bool:
6772
"""Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1."""
6873
if not _MPI4PY_AVAILABLE:
@@ -72,27 +77,33 @@ def detect() -> bool:
7277

7378
return MPI.COMM_WORLD.Get_size() > 1
7479

80+
@override
7581
@lru_cache(1)
7682
def world_size(self) -> int:
7783
return self._comm_world.Get_size()
7884

85+
@override
7986
def set_world_size(self, size: int) -> None:
8087
log.debug("MPIEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
8188

89+
@override
8290
@lru_cache(1)
8391
def global_rank(self) -> int:
8492
return self._comm_world.Get_rank()
8593

94+
@override
8695
def set_global_rank(self, rank: int) -> None:
8796
log.debug("MPIEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
8897

98+
@override
8999
@lru_cache(1)
90100
def local_rank(self) -> int:
91101
if self._comm_local is None:
92102
self._init_comm_local()
93103
assert self._comm_local is not None
94104
return self._comm_local.Get_rank()
95105

106+
@override
96107
def node_rank(self) -> int:
97108
if self._node_rank is None:
98109
self._init_comm_local()

src/lightning/fabric/plugins/environments/slurm.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import sys
2121
from typing import Optional
2222

23+
from typing_extensions import override
24+
2325
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
2426
from lightning.fabric.utilities.imports import _IS_WINDOWS
2527
from lightning.fabric.utilities.rank_zero import rank_zero_warn
@@ -52,10 +54,12 @@ def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Si
5254
self._validate_srun_variables()
5355

5456
@property
57+
@override
5558
def creates_processes_externally(self) -> bool:
5659
return True
5760

5861
@property
62+
@override
5963
def main_address(self) -> str:
6064
root_node = os.environ.get("MASTER_ADDR")
6165
if root_node is None:
@@ -67,6 +71,7 @@ def main_address(self) -> str:
6771
return root_node
6872

6973
@property
74+
@override
7075
def main_port(self) -> int:
7176
# -----------------------
7277
# SLURM JOB = PORT number
@@ -94,6 +99,7 @@ def main_port(self) -> int:
9499
return default_port
95100

96101
@staticmethod
102+
@override
97103
def detect() -> bool:
98104
"""Returns ``True`` if the current process was launched on a SLURM cluster.
99105
@@ -124,24 +130,31 @@ def job_id() -> Optional[int]:
124130
except ValueError:
125131
return None
126132

133+
@override
127134
def world_size(self) -> int:
128135
return int(os.environ["SLURM_NTASKS"])
129136

137+
@override
130138
def set_world_size(self, size: int) -> None:
131139
log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
132140

141+
@override
133142
def global_rank(self) -> int:
134143
return int(os.environ["SLURM_PROCID"])
135144

145+
@override
136146
def set_global_rank(self, rank: int) -> None:
137147
log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
138148

149+
@override
139150
def local_rank(self) -> int:
140151
return int(os.environ["SLURM_LOCALID"])
141152

153+
@override
142154
def node_rank(self) -> int:
143155
return int(os.environ["SLURM_NODEID"])
144156

157+
@override
145158
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
146159
if _is_slurm_interactive_mode():
147160
return

0 commit comments

Comments
 (0)