-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_mlx_float64_warning.py
More file actions
108 lines (83 loc) · 3.33 KB
/
test_mlx_float64_warning.py
File metadata and controls
108 lines (83 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import warnings
import numpy as np
import pytest
# MLX and the project wrappers around it are optional: record whether they
# imported cleanly so the whole module can be skipped instead of erroring.
try:
    import mlx.core as mx
    from adapters import MLXAdapter
    from kabsch_horn import mlx as kabsch_mlx
except ImportError:
    _MLX_AVAILABLE = False
else:
    _MLX_AVAILABLE = True

# Module-level mark: every test below is skipped when the imports above failed.
pytestmark = pytest.mark.skipif(not _MLX_AVAILABLE, reason="MLX not available")

# Fixed-seed (seed 0) float64 NumPy inputs of shape (8, 3), drawn uniformly
# from [0, 1). Presumably 8 points in 3-D, matching the Kabsch/Horn API.
# Both draws come from one generator, so P and Q differ but are reproducible.
_RNG = np.random.default_rng(0)
_P_NP = _RNG.random((8, 3), dtype=np.float64)
_Q_NP = _RNG.random((8, 3), dtype=np.float64)
@pytest.fixture
def P():
    """Provide the shared ``_P_NP`` point set as an MLX array with dtype float64."""
    p_float64 = mx.array(_P_NP, dtype=mx.float64)  # type: ignore[name-defined]
    return p_float64
@pytest.fixture
def Q():
    """Provide the shared ``_Q_NP`` point set as an MLX array with dtype float64."""
    q_float64 = mx.array(_Q_NP, dtype=mx.float64)  # type: ignore[name-defined]
    return q_float64
# Public ``kabsch_horn.mlx`` entry points exercised by every dtype test below.
# The pytest id of each case is the function's own name, so ids cannot drift
# from the functions they label.
_FN_NAMES = (
    "kabsch",
    "kabsch_umeyama",
    "kabsch_rmsd",
    "kabsch_umeyama_rmsd",
    "horn",
    "horn_with_scale",
)


def _fn_params():
    """Return a fresh list of ``pytest.param`` cases, one per name in ``_FN_NAMES``.

    Returns an empty list when MLX failed to import: ``kabsch_mlx`` is then
    undefined, and the module-level ``pytestmark`` skips the tests anyway.
    """
    if not _MLX_AVAILABLE:
        return []
    return [pytest.param(getattr(kabsch_mlx, name), id=name) for name in _FN_NAMES]


# One independent list per test so no parametrize call shares a mutable list.
_WARN_FNS = _fn_params()
_NO_WARN_FNS = _fn_params()
_MIXED_WARN_FNS = _fn_params()
@pytest.mark.parametrize("fn", _WARN_FNS)
def test_float64_emits_user_warning(fn, P, Q):
    """Two float64 MLX inputs must trigger a UserWarning mentioning float64.

    Per the original test's intent, the warning concerns a CPU fallback for
    float64. Only the substring "float64" in the message is checked here.
    """
    expect_float64_warning = pytest.warns(UserWarning, match="float64")
    with expect_float64_warning:
        fn(P, Q)
@pytest.mark.parametrize("fn", _NO_WARN_FNS)
def test_float32_no_warning(fn):
    """float32 MLX inputs must not emit a UserWarning mentioning float64.

    Only float64-related UserWarnings fail the test. Unrelated warnings that
    MLX or NumPy might raise are ignored, matching the ``match="float64"``
    criterion the sibling tests use for the positive cases.
    """
    P32 = mx.array(_P_NP.astype(np.float32))  # type: ignore[name-defined]
    Q32 = mx.array(_Q_NP.astype(np.float32))  # type: ignore[name-defined]
    with warnings.catch_warnings(record=True) as caught:
        # "always" so a warning emitted earlier in the session is not
        # suppressed by per-location deduplication and goes unrecorded.
        warnings.simplefilter("always")
        fn(P32, Q32)
    offending = [
        str(w.message)
        for w in caught
        if issubclass(w.category, UserWarning) and "float64" in str(w.message)
    ]
    assert not offending, f"float32 inputs emitted float64 warning(s): {offending}"
@pytest.mark.parametrize("fn", _MIXED_WARN_FNS)
def test_float32_p_float64_q_emits_warning(fn):
    """Regression test: a float64 Q must warn even when P is float32.

    An earlier implementation missed the dtype check on Q, so this
    mixed-precision call slipped through without a warning.
    """
    p_float32 = mx.array(_P_NP.astype(np.float32))  # type: ignore[name-defined]
    q_float64 = mx.array(_Q_NP, dtype=mx.float64)  # type: ignore[name-defined]
    expect_float64_warning = pytest.warns(UserWarning, match="float64")
    with expect_float64_warning:
        fn(p_float32, q_float64)
def test_mlx_adapter_float64_emits_warning():
    """The library's float64 warning must also surface when called through MLXAdapter."""
    adapter = MLXAdapter("float64")  # type: ignore[name-defined]
    # Convert P first, then Q, outside the warns() block as before: only the
    # kabsch call itself is required to warn.
    P, Q = map(adapter.convert_in, (_P_NP, _Q_NP))
    with pytest.warns(UserWarning, match="float64"):
        adapter.kabsch(P, Q)