
Commit f56df26

Sean Naren authored and lexierule committed

Add torchelastic check when sanitizing GPUs (#8095)

* Add torchelastic check
* Add changelog
* Address review
* fix

1 parent 0de287b · commit f56df26

File tree: 2 files changed, +32 −0 lines changed

pytorch_lightning/utilities/device_parser.py

Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 import torch
 
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _compare_version
@@ -78,6 +79,11 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
     gpus = _normalize_parse_gpu_input_to_list(gpus)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
+
+    if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
+        # omit sanity check on torchelastic as by default shows one visible GPU per process
+        return gpus
+
     gpus = _sanitize_gpu_ids(gpus)
 
     return gpus
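
As context for the new branch: TorchElasticEnvironment.is_using_torchelastic() detects a torchelastic launch from the worker's environment. A minimal sketch of such a check, assuming detection via the per-worker rank variables that torchelastic exports (the exact variable set here is an assumption, not necessarily the library's implementation):

import os


class TorchElasticEnvironment:

    @staticmethod
    def is_using_torchelastic() -> bool:
        # Assumed heuristic: torchelastic is in use when all of its
        # per-worker rank variables are present in the environment.
        required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
        return all(v in os.environ for v in required_env_vars)

Under that assumption, a torchelastic worker launched with one process per GPU sees CUDA_VISIBLE_DEVICES pinned to a single device, which is why sanitizing the requested ids against the visible devices would wrongly reject them.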

tests/models/test_gpu.py

Lines changed: 26 additions & 0 deletions

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import operator
+import os
 from collections import namedtuple
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -21,6 +23,7 @@
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _compare_version
@@ -219,6 +222,29 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_count, gpus):
         device_parser.parse_gpu_ids(gpus)
 
 
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "1",
+        "RANK": "3",
+        "WORLD_SIZE": "4",
+        "LOCAL_WORLD_SIZE": "2",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=1)
+@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
+def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
+    """
+    Ensure that when torchelastic is used and nproc_per_node is set to the default of one process per GPU device,
+    we omit sanitizing the GPUs, as only one of the GPUs is visible to each process.
+    """
+    trainer = Trainer(gpus=gpus)
+    assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment)
+    assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus)
+    assert trainer.gpus == gpus
+
+
 @RunIf(min_gpus=1)
 def test_single_gpu_batch_parse():
     trainer = Trainer(gpus=1)
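
To make the asserted behavior concrete outside the test harness, here is a small standalone sketch; the environment values and the patched device count mirror the mocks above, and parse_gpu_ids comes straight from this diff:

import os
from unittest import mock

from pytorch_lightning.utilities import device_parser

torchelastic_env = {
    "CUDA_VISIBLE_DEVICES": "0",
    "LOCAL_RANK": "1",
    "GROUP_RANK": "1",
    "RANK": "3",
    "WORLD_SIZE": "4",
    "LOCAL_WORLD_SIZE": "2",
}

# Only one GPU is visible to this process, so the torchelastic branch
# skips _sanitize_gpu_ids and returns the requested ids unchanged,
# instead of raising a MisconfigurationException for ids 1 and 2.
with mock.patch.dict(os.environ, torchelastic_env), \
        mock.patch("torch.cuda.device_count", return_value=1):
    assert device_parser.parse_gpu_ids([0, 1, 2]) == [0, 1, 2]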
