
Commit 6822960

awaelchli authored and lexierule committed
Fix support for passing -1 to find_usable_cuda_devices function (#16866)
Co-authored-by: Yi Heng Lim <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b0c7908 commit 6822960
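
For context, a minimal usage sketch of the behavior this commit fixes (the example output in the comments is illustrative, not part of the commit):

from lightning_fabric.accelerators.cuda import find_usable_cuda_devices

# With this fix, the default num_devices=-1 means "return every usable GPU"
# instead of always raising, since len(available_devices) can never equal -1.
devices = find_usable_cuda_devices(-1)
print(devices)  # e.g. [0, 1, 2, 3] on a machine with four idle GPUs

# Requesting an explicit count still raises a RuntimeError if fewer
# devices are actually usable, exactly as before.
two_devices = find_usable_cuda_devices(2)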

File tree

3 files changed, +8 −1 lines changed


src/lightning_fabric/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed edge cases in parsing device ids using NVML ([#16795](https://github.com/Lightning-AI/lightning/pull/16795))
 - Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
+- Fixed an error when passing `find_usable_cuda_devices(num_devices=-1)` ([#16866](https://github.com/Lightning-AI/lightning/pull/16866))
 
 
 ## [1.9.3] - 2023-02-21

src/lightning_fabric/accelerators/cuda.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> List[int]:
             # exit early if we found the right number of GPUs
             break
 
-    if len(available_devices) != num_devices:
+    if num_devices != -1 and len(available_devices) != num_devices:
         raise RuntimeError(
             f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
             f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."

tests/tests_fabric/accelerators/test_cuda.py

Lines changed: 6 additions & 0 deletions
@@ -143,3 +143,9 @@ def test_find_usable_cuda_devices_error_handling():
         "lightning_fabric.accelerators.cuda.torch.tensor", tensor_mock
     ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")):
         find_usable_cuda_devices(2)
+
+    # Request for as many GPUs as there are, no error should be raised
+    with mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch(
+        "lightning_fabric.accelerators.cuda.torch.tensor"
+    ):
+        assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4]
