Skip to content

Commit e356426

Browse files
committed
add unittests
1 parent eeedb32 commit e356426

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/tests_pytorch/accelerators/test_mps.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from collections import namedtuple
16+
from subprocess import SubprocessError
1617
from unittest import mock
1718

1819
import pytest
@@ -21,6 +22,7 @@
2122
import tests_pytorch.helpers.pipelines as tpipes
2223
from lightning.pytorch import Trainer
2324
from lightning.pytorch.accelerators import MPSAccelerator
25+
from lightning.pytorch.accelerators.mps import _get_mps_device_name
2426
from lightning.pytorch.demos.boring_classes import BoringModel
2527
from tests_pytorch.helpers.runif import RunIf
2628

@@ -151,3 +153,21 @@ def to(self, *args, **kwargs):
151153

152154
batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("mps"))
153155
assert batch.a.type() == "torch.mps.FloatTensor"
156+
157+
158+
@mock.patch("lightning.pytorch.accelerators.mps.subprocess.run")
159+
def test_get_mps_device_name(mock_run):
160+
mock_stdout = mock.MagicMock()
161+
mock_stdout.configure_mock(stdout="Apple M1 Pro\n")
162+
163+
mock_run.return_value = mock_stdout
164+
device_name = _get_mps_device_name()
165+
assert device_name == "Apple M1 Pro"
166+
167+
168+
@mock.patch(
169+
"lightning.pytorch.accelerators.mps.subprocess.run",
170+
side_effect=SubprocessError("test"),
171+
)
172+
def test_get_mps_device_name_exception(mock_run):
173+
assert _get_mps_device_name() == "True (mps)"

0 commit comments

Comments
 (0)