Skip to content

Commit e10118e

Browse files
authored
fix: Disable MPS for Torch versions >=2.8.0 (#287)
* Updated device selection to exclude MPS for certain Torch versions
1 parent 66c30a5 commit e10118e

File tree

2 files changed

+54
-29
lines changed

2 files changed

+54
-29
lines changed

model2vec/distill/utils.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,47 @@
33
from logging import getLogger
44

55
import torch
6+
from packaging import version
67

78
logger = getLogger(__name__)
89

910

1011
def select_optimal_device(device: str | None) -> str:
    """
    Get the optimal device to use based on backend availability.

    For Torch versions >= 2.8.0, MPS is disabled due to known performance regressions.
    An explicitly requested device is passed through unchanged, unless it targets MPS
    (either "mps" or an indexed form such as "mps:0") on an affected Torch version.
    When no device is given, CUDA is preferred, then MPS (if not disabled), then CPU.

    :param device: The device to use. If this is None (or an empty string), the device is automatically selected.
    :return: The selected device.
    :raises RuntimeError: If MPS is requested on a PyTorch version where it is disabled.
    """
    # Drop any local build suffix (the part after "+", e.g. "+cu121") before parsing.
    torch_version = version.parse(torch.__version__.split("+")[0])
    mps_broken = torch_version >= version.parse("2.8.0")

    if device:
        # Compare on the device type only, so that "mps:0" is rejected just like "mps".
        device_type = device.split(":", 1)[0]
        if device_type == "mps" and mps_broken:
            raise RuntimeError(
                f"MPS is disabled for PyTorch {torch.__version__} due to known performance regressions. "
                "Please use CPU or CUDA instead, or use a PyTorch version < 2.8.0."
            )
        # Any other explicit request is honored as-is; it is not validated here.
        return device

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        if mps_broken:
            logger.warning(
                f"MPS is available but PyTorch {torch.__version__} has known performance regressions. "
                "Falling back to CPU. Please use a PyTorch version < 2.8.0 to enable MPS support."
            )
            device = "cpu"
        else:
            device = "mps"
    else:
        device = "cpu"

    logger.info(f"Automatically selected device: {device}")
    return device

tests/test_utils.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
from __future__ import annotations
22

3-
import json
43
from pathlib import Path
5-
from tempfile import NamedTemporaryFile, TemporaryDirectory
6-
from typing import Any
4+
from tempfile import NamedTemporaryFile
75
from unittest.mock import patch
86

9-
import numpy as np
107
import pytest
11-
import safetensors
12-
import safetensors.numpy
13-
from tokenizers import Tokenizer
148

159
from model2vec.distill.utils import select_optimal_device
1610
from model2vec.hf_utils import _get_metadata_from_readme
@@ -39,26 +33,36 @@ def test__get_metadata_from_readme_mocked_file_keys() -> None:
3933

4034

4135
@pytest.mark.parametrize(
    "torch_version, device, expected, cuda, mps, should_raise",
    [
        ("2.7.0", "cpu", "cpu", True, True, False),
        ("2.8.0", "cpu", "cpu", True, True, False),
        ("2.7.0", "clown", "clown", False, False, False),
        ("2.8.0", "clown", "clown", False, False, False),
        ("2.7.0", "mps", "mps", False, True, False),
        ("2.8.0", "mps", None, False, True, True),
        ("2.7.0", None, "cuda", True, True, False),
        ("2.7.0", None, "mps", False, True, False),
        ("2.7.0", None, "cpu", False, False, False),
        ("2.8.0", None, "cuda", True, True, False),
        ("2.8.0", None, "cpu", False, True, False),
        ("2.8.0", None, "cpu", False, False, False),
        ("2.9.0", None, "cpu", False, True, False),
        ("3.0.0", None, "cpu", False, True, False),
    ],
)
def test_select_optimal_device(torch_version, device, expected, cuda, mps, should_raise) -> None:
    """
    Check device selection for each combination of Torch version and backend availability.

    Each row patches the reported Torch version and the CUDA/MPS availability checks,
    then either expects a RuntimeError (should_raise) or compares the returned device
    against the expected one.
    """
    cuda_patch = patch("torch.cuda.is_available", return_value=cuda)
    mps_patch = patch("torch.backends.mps.is_available", return_value=mps)
    version_patch = patch("torch.__version__", torch_version)

    with cuda_patch, mps_patch, version_patch:
        if not should_raise:
            assert select_optimal_device(device) == expected
            return
        # Rows flagged should_raise carry no meaningful expected value.
        with pytest.raises(RuntimeError):
            select_optimal_device(device)
6266

6367

6468
def test_importable() -> None:

0 commit comments

Comments
 (0)