Skip to content

Commit e95d8b1

Browse files
awaelchlitchatoncarmoccaBordaAndrew Tritt
authored andcommitted
Modify LSFEnvironment to use more reliable environment variable (#10825)
Co-authored-by: thomas chaton <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Andrew Tritt <[email protected]>
1 parent 8cffc0f commit e95d8b1

File tree

3 files changed

+197
-116
lines changed

3 files changed

+197
-116
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
1313
- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
1414

15+
### Changed
16+
17+
- Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))
18+
1519

1620
## [1.5.8] - 2022-01-05
1721

pytorch_lightning/plugins/environments/lsf_environment.py

Lines changed: 112 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
import os
1616
import socket
17+
from typing import Dict, List
1718

1819
from pytorch_lightning import _logger as log
1920
from pytorch_lightning.plugins.environments import ClusterEnvironment
21+
from pytorch_lightning.utilities import rank_zero_deprecation
22+
from pytorch_lightning.utilities.cloud_io import get_filesystem
2023

2124

2225
class LSFEnvironment(ClusterEnvironment):
@@ -25,128 +28,161 @@ class LSFEnvironment(ClusterEnvironment):
2528
It is expected that any execution using this ClusterEnvironment was executed
2629
using the Job Step Manager i.e. ``jsrun``.
2730
28-
This plugin expects the following environment variables.
31+
This plugin expects the following environment variables:
2932
30-
LSB_JOBID:
31-
The LSF assigned job ID
33+
``LSB_JOBID``
34+
The LSF assigned job ID
3235
33-
LSB_HOSTS:
34-
The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."
36+
``LSB_DJOB_RANKFILE``
37+
The OpenMPI compatibile rank file for the LSF job
3538
36-
JSM_NAMESPACE_LOCAL_RANK:
37-
The node local rank for the task. This environment variable is set by jsrun
39+
``JSM_NAMESPACE_LOCAL_RANK``
40+
The node local rank for the task. This environment variable is set by ``jsrun``
3841
39-
JSM_NAMESPACE_SIZE:
40-
The world size for the task. This environment variable is set by jsrun
41-
"""
42+
``JSM_NAMESPACE_SIZE``
43+
The world size for the task. This environment variable is set by ``jsrun``
4244
43-
def __init__(self):
44-
self._master_address = self._get_master_address()
45-
self._master_port = self._get_master_port()
46-
log.debug(f"MASTER_ADDR: {self._master_address}")
47-
log.debug(f"MASTER_PORT: {self._master_port}")
45+
``JSM_NAMESPACE_RANK``
46+
The global rank for the task. This environment variable is set by ``jsrun``
47+
"""
4848

49-
@staticmethod
50-
def is_using_lsf() -> bool:
51-
"""Returns ``True`` if the current process was launched using the jsrun command."""
52-
required_env_vars = ("LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
53-
return all(v in os.environ for v in required_env_vars)
49+
def __init__(self) -> None:
50+
super().__init__()
51+
# TODO: remove in 1.7
52+
if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
53+
rank_zero_deprecation(
54+
f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7."
55+
" Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)."
56+
)
57+
self._main_address = self._get_main_address()
58+
self._main_port = self._get_main_port()
59+
self._node_rank = self._get_node_rank()
60+
self._set_init_progress_group_env_vars()
61+
62+
def _set_init_progress_group_env_vars(self) -> None:
63+
# set environment variables needed for initializing torch distributed process group
64+
os.environ["MASTER_ADDR"] = str(self._main_address)
65+
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
66+
os.environ["MASTER_PORT"] = str(self._main_port)
67+
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
5468

5569
@property
5670
def creates_processes_externally(self) -> bool:
71+
"""LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
5772
return True
5873

59-
def master_address(self):
60-
"""The master address is read from a list of hosts contained in the environment variable `LSB_HOSTS`."""
61-
return self._master_address
74+
def master_address(self) -> str:
75+
"""The main address is read from an OpenMPI host rank file in the environment variable
76+
``LSB_DJOB_RANKFILE``."""
77+
return self._main_address
78+
79+
def master_port(self) -> int:
80+
"""The main port is calculated from the LSF job ID."""
81+
return self._main_port
6282

63-
def master_port(self):
64-
"""THe master port gets calculated from the LSF job ID."""
65-
return self._master_port
83+
@staticmethod
84+
def is_using_lsf() -> bool:
85+
"""Returns ``True`` if the current process was launched using the ``jsrun`` command."""
86+
required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
87+
return required_env_vars.issubset(os.environ.keys())
6688

67-
def world_size(self):
68-
"""The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
69-
var = "JSM_NAMESPACE_SIZE"
70-
world_size = os.environ.get(var)
89+
def world_size(self) -> int:
90+
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
91+
world_size = os.environ.get("JSM_NAMESPACE_SIZE")
7192
if world_size is None:
7293
raise ValueError(
73-
f"Cannot determine world size from environment variable {var}."
74-
" Make sure you run your executable with `jsrun`"
94+
"Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
95+
"Make sure you run your executable with `jsrun`."
7596
)
7697
return int(world_size)
7798

7899
def set_world_size(self, size: int) -> None:
79100
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
80101

81-
def global_rank(self):
82-
"""The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
83-
var = "JSM_NAMESPACE_RANK"
84-
global_rank = os.environ.get(var)
102+
def global_rank(self) -> int:
103+
"""The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
104+
global_rank = os.environ.get("JSM_NAMESPACE_RANK")
85105
if global_rank is None:
86106
raise ValueError(
87-
f"Cannot determine global rank from environment variable {var}."
88-
" Make sure you run your executable with `jsrun`"
107+
"Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
108+
"Make sure you run your executable with `jsrun`."
89109
)
90110
return int(global_rank)
91111

92112
def set_global_rank(self, rank: int) -> None:
93113
log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
94114

95-
def local_rank(self):
115+
def local_rank(self) -> int:
96116
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
97-
var = "JSM_NAMESPACE_LOCAL_RANK"
98-
local_rank = os.environ.get(var)
117+
local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
99118
if local_rank is None:
100119
raise ValueError(
101-
f"Cannot determine local rank from environment variable {var}."
102-
" Make sure you run your executable with `jsrun`"
120+
"Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
121+
"Make sure you run your executable with `jsrun`."
103122
)
104123
return int(local_rank)
105124

106-
def node_rank(self):
107-
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
108-
environment variable `LSB_HOSTS`."""
125+
def node_rank(self) -> int:
126+
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored
127+
in ``LSB_DJOB_RANKFILE``."""
128+
return self._node_rank
129+
130+
def _get_node_rank(self) -> int:
131+
"""A helper method for getting the node rank.
132+
133+
The node rank is determined by the position of the current node in the list of hosts used in the job. This is
134+
calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
135+
"""
109136
hosts = self._read_hosts()
110-
count = {}
137+
count: Dict[str, int] = {}
111138
for host in hosts:
112-
if "batch" in host or "login" in host:
113-
continue
114139
if host not in count:
115140
count[host] = len(count)
116141
return count[socket.gethostname()]
117142

118143
@staticmethod
119-
def _read_hosts():
120-
hosts = os.environ.get("LSB_HOSTS")
121-
if not hosts:
122-
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
123-
hosts = hosts.split()
124-
if len(hosts) < 2:
125-
raise ValueError(
126-
'Cannot parse hosts from LSB_HOSTS environment variable. Expected format: "batch <rank_0_host> ..."'
127-
)
128-
return hosts
144+
def _read_hosts() -> List[str]:
145+
"""Read compute hosts that are a part of the compute job.
129146
130-
def _get_master_address(self):
147+
LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
148+
Each job is assigned a launch node. This launch node will be the first node in the list contained in
149+
``LSB_DJOB_RANKFILE``.
150+
"""
151+
var = "LSB_DJOB_RANKFILE"
152+
rankfile = os.environ.get(var)
153+
if rankfile is None:
154+
raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
155+
if not rankfile:
156+
raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
157+
158+
fs = get_filesystem(rankfile)
159+
with fs.open(rankfile, "r") as f:
160+
ret = [line.strip() for line in f]
161+
# remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
162+
return ret[1:]
163+
164+
def _get_main_address(self) -> str:
165+
"""A helper for getting the main address.
166+
167+
The main address is assigned to the first node in the list of nodes used for the job.
168+
"""
131169
hosts = self._read_hosts()
132-
return hosts[1]
170+
return hosts[0]
133171

134172
@staticmethod
135-
def _get_master_port():
136-
"""A helper function for accessing the master port.
173+
def _get_main_port() -> int:
174+
"""A helper function for accessing the main port.
137175
138-
Uses the LSF job ID so all ranks can compute the master port.
176+
Uses the LSF job ID so all ranks can compute the main port.
139177
"""
140-
# check for user-specified master port
141-
port = os.environ.get("MASTER_PORT")
142-
if not port:
143-
jobid = os.environ.get("LSB_JOBID")
144-
if not jobid:
145-
raise ValueError("Could not find job id in environment variable LSB_JOBID")
146-
port = int(jobid)
178+
# check for user-specified main port
179+
if "MASTER_PORT" in os.environ:
180+
log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
181+
return int(os.environ["MASTER_PORT"])
182+
if "LSB_JOBID" in os.environ:
183+
port = int(os.environ["LSB_JOBID"])
147184
# all ports should be in the 10k+ range
148-
port = int(port) % 1000 + 10000
149-
log.debug(f"calculated LSF master port: {port}")
150-
else:
151-
log.debug(f"using externally specified master port: {port}")
152-
return int(port)
185+
port = port % 1000 + 10000
186+
log.debug(f"calculated LSF main port: {port}")
187+
return port
188+
raise ValueError("Could not find job id in environment variable LSB_JOBID")

0 commit comments

Comments
 (0)