Skip to content

Commit e615051

Browse files
committed
Updated device selection to exclude MPS for certain Torch versions
1 parent 9d201e5 commit e615051

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

model2vec/distill/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def select_optimal_device(device: str | None) -> str:
2626
if device == "mps" and mps_broken:
2727
raise RuntimeError(
2828
f"MPS is disabled for PyTorch {torch.__version__} due to known performance regressions. "
29-
"Please use CPU or CUDA instead."
29+
"Please use CPU or CUDA instead, or use a PyTorch version < 2.8.0."
3030
)
3131
else:
3232
return device
@@ -37,7 +37,7 @@ def select_optimal_device(device: str | None) -> str:
3737
if mps_broken:
3838
logger.warning(
3939
f"MPS is available but PyTorch {torch.__version__} has known performance regressions. "
40-
"Falling back to CPU."
40+
"Falling back to CPU. Please use a PyTorch version < 2.8.0 to enable MPS support."
4141
)
4242
device = "cpu"
4343
else:

0 commit comments

Comments
 (0)