12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import operator |
| 15 | +import os |
15 | 16 | from collections import namedtuple |
| 17 | +from unittest import mock |
16 | 18 | from unittest.mock import patch |
17 | 19 |
18 | 20 | import pytest |
21 | 23 | import tests.helpers.pipelines as tpipes |
22 | 24 | import tests.helpers.utils as tutils |
23 | 25 | from pytorch_lightning import Trainer |
| 26 | +from pytorch_lightning.plugins.environments import TorchElasticEnvironment |
24 | 27 | from pytorch_lightning.utilities import device_parser |
25 | 28 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
26 | 29 | from pytorch_lightning.utilities.imports import _compare_version |
@@ -219,6 +222,29 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun |
219 | 222 | device_parser.parse_gpu_ids(gpus) |
220 | 223 |
221 | 224 |
| 225 | +@mock.patch.dict( |
| 226 | + os.environ, { |
| 227 | + "CUDA_VISIBLE_DEVICES": "0", |
| 228 | + "LOCAL_RANK": "1", |
| 229 | + "GROUP_RANK": "1", |
| 230 | + "RANK": "3", |
| 231 | + "WORLD_SIZE": "4", |
| 232 | + "LOCAL_WORLD_SIZE": "2", |
| 233 | + } |
| 234 | +) |
| 235 | +@mock.patch("torch.cuda.device_count", return_value=1)
| 236 | +@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, "0"])
| 237 | +def test_torchelastic_gpu_parsing(mocked_device_count, gpus): |
| 238 | + """ |
| 239 | + Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device |
| 240 | + That we omit sanitizing the gpus as only one of the GPUs is visible. |
| 241 | + """ |
| 242 | + trainer = Trainer(gpus=gpus) |
| 243 | + assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment) |
| 244 | + assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus) |
| 245 | + assert trainer.gpus == gpus |
| 246 | + |
| 247 | + |
222 | 248 | @RunIf(min_gpus=1) |
223 | 249 | def test_single_gpu_batch_parse(): |
224 | 250 | trainer = Trainer(gpus=1) |
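For context, here is a minimal sketch (not part of this diff) of the detection signal the new test relies on. The environment variables mocked above (`LOCAL_RANK`, `GROUP_RANK`, `RANK`, `WORLD_SIZE`, ...) are the ones torchelastic (e.g. `torch.distributed.run`) exports into each worker it launches; in Lightning the actual check lives in `TorchElasticEnvironment`, while the helper name below is hypothetical:

```python
import os


def _running_under_torchelastic() -> bool:
    # Hypothetical helper, assuming torchelastic's documented behaviour of
    # exporting these variables into every worker process it launches.
    required = ("LOCAL_RANK", "GROUP_RANK", "RANK", "WORLD_SIZE")
    return all(var in os.environ for var in required)


# Mirror the environment mocked in the test above: worker rank 3 of 4,
# with CUDA_VISIBLE_DEVICES restricting the process to a single GPU.
os.environ.update({"LOCAL_RANK": "1", "GROUP_RANK": "1", "RANK": "3", "WORLD_SIZE": "4"})
print(_running_under_torchelastic())  # True
```

Because detection happens before device sanitization, `Trainer(gpus=[0, 1, 2])` keeps the requested ids even though `torch.cuda.device_count()` reports a single visible device in each worker, which is exactly what the new assertions verify.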
|