Skip to content

Commit c323507

Browse files
committed
fix: restore torch default dtype once test_submodules_context_device_and_dtype is finished.
1 parent dd12c59 commit c323507

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

tests/tests_fabric/utilities/test_device_dtype_mixin.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from torch import nn as nn
44

5+
from lightning.fabric.plugins.precision.utils import _DtypeContextManager
56
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
67
from tests_fabric.helpers.runif import RunIf
78

@@ -58,14 +59,17 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
5859
pytest.param("mps:0", marks=RunIf(mps=True)),
5960
],
6061
)
61-
@pytest.mark.parametrize("dst_type", [torch.half, torch.float, torch.double])
62+
@pytest.mark.parametrize(
63+
"dst_type",
64+
[
65+
torch.float,
66+
pytest.param(torch.half, marks=RunIf(mps=False)),
67+
pytest.param(torch.double, marks=RunIf(mps=False)),
68+
],
69+
)
6270
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
63-
if dst_device_str == "mps:0" and dst_type in (torch.half, torch.double):
64-
pytest.skip("MPS does not yet support half and double.")
65-
6671
dst_device = torch.device(dst_device_str)
67-
torch.set_default_dtype(dst_type)
68-
with dst_device:
72+
with _DtypeContextManager(dst_type), dst_device:
6973
model = TopModule()
7074
assert model.device == dst_device
7175
assert model.dtype == dst_type

0 commit comments

Comments
 (0)