|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import json |
4 | 3 | from pathlib import Path |
5 | | -from tempfile import NamedTemporaryFile, TemporaryDirectory |
6 | | -from typing import Any |
| 4 | +from tempfile import NamedTemporaryFile |
7 | 5 | from unittest.mock import patch |
8 | 6 |
|
9 | | -import numpy as np |
10 | 7 | import pytest |
11 | | -import safetensors |
12 | | -import safetensors.numpy |
13 | | -from tokenizers import Tokenizer |
14 | 8 |
|
15 | 9 | from model2vec.distill.utils import select_optimal_device |
16 | 10 | from model2vec.hf_utils import _get_metadata_from_readme |
@@ -39,26 +33,36 @@ def test__get_metadata_from_readme_mocked_file_keys() -> None: |
39 | 33 |
|
40 | 34 |
|
@pytest.mark.parametrize(
    # Columns: torch version string, requested device (None = auto-select),
    # expected selected device (None when an error is expected instead),
    # whether CUDA reports available, whether MPS reports available,
    # and whether select_optimal_device should raise RuntimeError.
    "torch_version, device, expected, cuda, mps, should_raise",
    [
        # Explicit "cpu" always wins, regardless of version or backends.
        ("2.7.0", "cpu", "cpu", True, True, False),
        ("2.8.0", "cpu", "cpu", True, True, False),
        # Unknown device strings are passed through unchanged.
        ("2.7.0", "clown", "clown", False, False, False),
        ("2.8.0", "clown", "clown", False, False, False),
        # Explicit "mps": accepted on 2.7.x, rejected (raises) on 2.8.0.
        ("2.7.0", "mps", "mps", False, True, False),
        ("2.8.0", "mps", None, False, True, True),
        # Auto-selection on 2.7.x: cuda > mps > cpu.
        ("2.7.0", None, "cuda", True, True, False),
        ("2.7.0", None, "mps", False, True, False),
        ("2.7.0", None, "cpu", False, False, False),
        # Auto-selection on 2.8.0+: cuda > cpu (mps is skipped even if available).
        ("2.8.0", None, "cuda", True, True, False),
        ("2.8.0", None, "cpu", False, True, False),
        ("2.8.0", None, "cpu", False, False, False),
        ("2.9.0", None, "cpu", False, True, False),
        ("3.0.0", None, "cpu", False, True, False),
    ],
)
def test_select_optimal_device(
    torch_version: str,
    device: str | None,
    expected: str | None,
    cuda: bool,
    mps: bool,
    should_raise: bool,
) -> None:
    """Test whether the optimal device is selected across torch versions and backends.

    Patches CUDA/MPS availability and the reported torch version so the
    selection logic can be exercised without real hardware.
    """
    with (
        patch("torch.cuda.is_available", return_value=cuda),
        patch("torch.backends.mps.is_available", return_value=mps),
        patch("torch.__version__", torch_version),
    ):
        if should_raise:
            # MPS is expected to be rejected on torch >= 2.8.0.
            with pytest.raises(RuntimeError):
                select_optimal_device(device)
        else:
            assert select_optimal_device(device) == expected
62 | 66 |
|
63 | 67 |
|
64 | 68 | def test_importable() -> None: |
|
0 commit comments