|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | from collections import namedtuple |
| 16 | +from subprocess import SubprocessError |
16 | 17 | from unittest import mock |
17 | 18 |
|
18 | 19 | import pytest |
|
21 | 22 | import tests_pytorch.helpers.pipelines as tpipes |
22 | 23 | from lightning.pytorch import Trainer |
23 | 24 | from lightning.pytorch.accelerators import MPSAccelerator |
| 25 | +from lightning.pytorch.accelerators.mps import _get_mps_device_name |
24 | 26 | from lightning.pytorch.demos.boring_classes import BoringModel |
25 | 27 | from tests_pytorch.helpers.runif import RunIf |
26 | 28 |
|
@@ -151,3 +153,21 @@ def to(self, *args, **kwargs): |
151 | 153 |
|
152 | 154 | batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("mps")) |
153 | 155 | assert batch.a.type() == "torch.mps.FloatTensor" |
| 156 | + |
| 157 | + |
| 158 | +@mock.patch("lightning.pytorch.accelerators.mps.subprocess.run") |
| 159 | +def test_get_mps_device_name(mock_run): |
| 160 | + mock_stdout = mock.MagicMock() |
| 161 | + mock_stdout.configure_mock(stdout="Apple M1 Pro\n") |
| 162 | + |
| 163 | + mock_run.return_value = mock_stdout |
| 164 | + device_name = _get_mps_device_name() |
| 165 | + assert device_name == "Apple M1 Pro" |
| 166 | + |
| 167 | + |
| 168 | +@mock.patch( |
| 169 | + "lightning.pytorch.accelerators.mps.subprocess.run", |
| 170 | + side_effect=SubprocessError("test"), |
| 171 | +) |
| 172 | +def test_get_mps_device_name_exception(mock_run): |
| 173 | + assert _get_mps_device_name() == "True (mps)" |
0 commit comments