-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_reference_validation.py
More file actions
117 lines (90 loc) · 4.11 KB
/
test_reference_validation.py
File metadata and controls
117 lines (90 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
External reference validation.
Validates our implementations against two independent external libraries:
- `rmsd` package (Kabsch algorithm, charnley/rmsd)
- `scipy.spatial.transform` (Horn-equivalent via Rotation.align_vectors)
"""
import numpy as np
import pytest
import rmsd as rmsd_lib
from adapters import FrameworkAdapter, frameworks
from scipy.spatial.transform import Rotation
from kabsch_horn import numpy as kabsch_np
def _reference_kabsch_3d(P: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Reference Kabsch rotation computed by the ``rmsd`` package (charnley/rmsd).

    Points are the rows of ``P`` and ``Q``. ``rmsd.kabsch(P, Q)`` returns a
    matrix U such that ``P @ U`` is aligned onto ``Q``. U therefore acts on
    row vectors from the right (row-vector convention).

    Our convention is ``P @ R.T + t = Q``, so our R equals U transposed. The
    transpose is returned so callers can compare it to our R directly.

    Args:
        P: Source point cloud, one point per row.
        Q: Target point cloud, one point per row, paired with ``P``.

    Returns:
        The rotation matrix R in our convention.
    """
    # rmsd.kabsch assumes both sets are already centred on their centroids.
    centred_p = P - np.mean(P, axis=0)
    centred_q = Q - np.mean(Q, axis=0)
    u_right = rmsd_lib.kabsch(centred_p, centred_q)
    return np.transpose(u_right)
def _reference_horn_3d(P: np.ndarray, Q: np.ndarray) -> np.ndarray:
"""Horn-equivalent rotation via scipy.
Rotation.align_vectors(a, b) finds R s.t. R @ b[i] ~= a[i].
Our convention: P @ R.T + t = Q => R @ P[i] ~= Q[i] - t.
So align_vectors(Q_c, P_c) gives our R.
"""
P_c = P - P.mean(0)
Q_c = Q - Q.mean(0)
result = Rotation.align_vectors(Q_c, P_c)
return result[0].as_matrix()
_SEEDS = [42, 123, 7, 99, 314]
class TestReferenceValidation:
    """Cross-library checks of our alignment results against rmsd and SciPy."""

    @staticmethod
    def _random_clouds(seed: int) -> "tuple[np.ndarray, np.ndarray]":
        """Draw two independent uniform (20, 3) point clouds from ``seed``.

        The first cloud is drawn before the second, so a given seed always
        yields the same (P, Q) pair.
        """
        shape = (20, 3)
        gen = np.random.default_rng(seed)
        first = gen.random(shape)
        second = gen.random(shape)
        return first, second

    @pytest.mark.parametrize("seed", _SEEDS)
    @pytest.mark.parametrize("adapter", frameworks)
    def test_kabsch_matches_rmsd_package(
        self, adapter: FrameworkAdapter, seed: int
    ) -> None:
        """Each backend's kabsch rotation agrees with rmsd.kabsch for every seed."""
        p_host, q_host = self._random_clouds(seed)
        expected = _reference_kabsch_3d(p_host, q_host)
        outputs = adapter.kabsch(
            adapter.convert_in(p_host), adapter.convert_in(q_host)
        )
        actual = adapter.convert_out(outputs[0])
        # The tolerance is widened 10x because SVD routines in different
        # libraries (LAPACK variants) differ by roughly eps * cond(H).
        tolerance = 10 * adapter.atol
        np.testing.assert_allclose(actual, expected, atol=tolerance)

    @pytest.mark.parametrize("seed", _SEEDS)
    @pytest.mark.parametrize("adapter", frameworks)
    def test_horn_matches_scipy(self, adapter: FrameworkAdapter, seed: int) -> None:
        """Each backend's horn rotation agrees with SciPy's align_vectors."""
        p_host, q_host = self._random_clouds(seed)
        expected = _reference_horn_3d(p_host, q_host)
        outputs = adapter.horn(
            adapter.convert_in(p_host), adapter.convert_in(q_host)
        )
        actual = adapter.convert_out(outputs[0])
        # The tolerance is widened 10x because SciPy solves the same
        # least-squares problem with different numerics from ours.
        tolerance = 10 * adapter.atol
        np.testing.assert_allclose(actual, expected, atol=tolerance)

    @pytest.mark.parametrize("seed", _SEEDS)
    def test_umeyama_rotation_matches_rmsd_reference(self, seed: int) -> None:
        """The rotation returned by NumPy kabsch_umeyama matches rmsd.kabsch."""
        p_host, q_host = self._random_clouds(seed)
        expected = _reference_kabsch_3d(p_host, q_host)
        # Only the first of the four returned values (the rotation) is compared.
        rotation, _, _, _ = kabsch_np.kabsch_umeyama(p_host, q_host)
        # Both sides run in float64 on 20x3 random clouds, so a tight
        # absolute tolerance is used.
        np.testing.assert_allclose(rotation, expected, atol=1e-8)

    @pytest.mark.parametrize("seed", _SEEDS)
    def test_rmsd_value_matches_rmsd_package(self, seed: int) -> None:
        """The NumPy kabsch RMSD scalar matches rmsd.kabsch_rmsd for every seed."""
        p_host, q_host = self._random_clouds(seed)
        # rmsd.kabsch_rmsd is fed centred clouds. Our kabsch receives the
        # raw clouds and handles the translation itself.
        expected = rmsd_lib.kabsch_rmsd(
            p_host - np.mean(p_host, axis=0), q_host - np.mean(q_host, axis=0)
        )
        _, _, actual = kabsch_np.kabsch(p_host, q_host)
        # Both sides run in float64 on 20x3 random clouds.
        assert float(actual) == pytest.approx(float(expected), abs=1e-8)