Skip to content

Commit a1bf2ee

Browse files
committed
fix test.
1 parent fcd8dc9 commit a1bf2ee

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/lightning/pytorch/accelerators/mps.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,21 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
9393
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
9494
if not cls.is_available():
9595
return ""
96-
try:
97-
result = subprocess.run(
98-
["sysctl", "-n", "machdep.cpu.brand_string"],
99-
capture_output=True,
100-
text=True,
101-
check=True,
102-
)
103-
result_str = result.stdout.strip()
104-
except (subprocess.SubprocessError, FileNotFoundError):
105-
result_str = "True (mps)"
106-
return result_str
96+
return _get_mps_device_name()
97+
98+
99+
def _get_mps_device_name() -> str:
100+
try:
101+
result = subprocess.run(
102+
["sysctl", "-n", "machdep.cpu.brand_string"],
103+
capture_output=True,
104+
text=True,
105+
check=True,
106+
)
107+
result_str = result.stdout.strip()
108+
except subprocess.SubprocessError:
109+
result_str = "True (mps)"
110+
return result_str
107111

108112

109113
# device metrics

tests/tests_pytorch/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ def cuda_count_4(monkeypatch):
208208
def mock_mps_count(monkeypatch, n: int) -> None:
209209
monkeypatch.setattr(lightning.fabric.accelerators.mps, "_get_all_available_mps_gpus", lambda: [0] if n > 0 else [])
210210
monkeypatch.setattr(lightning.fabric.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0)
211+
monkeypatch.setattr(
212+
lightning.pytorch.accelerators.mps, "_get_mps_device_name", lambda: "Mocked MPS Device" if n > 0 else ""
213+
)
211214

212215

213216
@pytest.fixture

0 commit comments

Comments
 (0)