Skip to content

Commit 990e661

Browse files
committed
add torch 2.2.2 or below compatibility
1 parent e249722 commit 990e661

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/lightning/fabric/utilities/device_dtype_mixin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ class _DeviceDtypeModuleMixin(Module):
2525
def __init__(self) -> None:
2626
super().__init__()
2727
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
28-
self._device: torch.device = torch.get_default_device()
28+
# Workarounds from the original pytorch issue:
29+
# https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
30+
self._device = torch.get_default_device() if torch.__version__ >= "2.3.0" else torch.empty(0).device
2931

3032
@property
3133
def dtype(self) -> Union[str, torch.dtype]:

tests/tests_fabric/utilities/test_device_dtype_mixin.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,27 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
5050
assert model.dtype == model.module.module.dtype == dst_type
5151

5252

53+
@pytest.mark.parametrize(
54+
"dst_device_str",
55+
[
56+
"cpu",
57+
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
58+
pytest.param("mps:0", marks=RunIf(mps=True)),
59+
],
60+
)
61+
@pytest.mark.parametrize("dst_type", [torch.half, torch.float, torch.double])
62+
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
63+
if dst_device_str == "mps:0" and dst_type in (torch.half, torch.double):
64+
pytest.skip("MPS does not yet support half and double.")
65+
66+
dst_device = torch.device(dst_device_str)
67+
torch.set_default_dtype(dst_type)
68+
with dst_device:
69+
model = TopModule()
70+
assert model.device == dst_device
71+
assert model.dtype == dst_type
72+
73+
5374
@pytest.mark.parametrize(
5475
"device",
5576
[

0 commit comments

Comments
 (0)