|
12 | 12 | from load_tests import pygsp |
13 | 13 | from load_tests import sp |
14 | 14 | from load_tests import TruncatedSVD |
15 | | -from nose.tools import assert_raises_regex |
16 | | -from nose.tools import assert_warns_regex |
| 15 | + |
| 16 | +import pytest |
17 | 17 | from scipy.sparse.csgraph import shortest_path |
| 18 | + |
18 | 19 | from scipy.spatial.distance import pdist |
19 | 20 | from scipy.spatial.distance import squareform |
20 | 21 |
|
@@ -50,17 +51,17 @@ def test_build_knn_with_sample_idx(): |
50 | 51 |
|
51 | 52 |
|
52 | 53 | def test_duplicate_data(): |
53 | | - with assert_warns_regex( |
| 54 | + with pytest.warns( |
54 | 55 | RuntimeWarning, |
55 | | - r"Detected zero distance between samples ([0-9and,\s]*). Consider removing duplicates to avoid errors in downstream processing.", |
| 56 | + match=r"Detected zero distance between samples ([0-9and,\s]*). Consider removing duplicates to avoid errors in downstream processing.", |
56 | 57 | ): |
57 | 58 | build_graph(np.vstack([data, data[:9]]), n_pca=None, decay=10, thresh=1e-4) |
58 | 59 |
|
59 | 60 |
|
60 | 61 | def test_duplicate_data_many(): |
61 | | - with assert_warns_regex( |
| 62 | + with pytest.warns( |
62 | 63 | RuntimeWarning, |
63 | | - "Detected zero distance between ([0-9]*) pairs of samples. Consider removing duplicates to avoid errors in downstream processing.", |
| 64 | + match=r"Detected zero distance between ([0-9and,\s]*) pairs of samples. Consider removing duplicates to avoid errors in downstream processing.", |
64 | 65 | ): |
65 | 66 | build_graph(np.vstack([data, data[:21]]), n_pca=None, decay=10, thresh=1e-4) |
66 | 67 |
|
@@ -305,8 +306,9 @@ def test_knnmax(): |
305 | 306 | ) |
306 | 307 | assert isinstance(G2, graphtools.graphs.kNNGraph) |
307 | 308 | assert G.N == G2.N |
308 | | - assert np.all(G.dw == G2.dw) |
309 | | - assert (G.W - G2.W).nnz == 0 |
| 309 | + np.testing.assert_allclose(G.dw, G2.dw) |
| 310 | + # Use allclose for sparse matrices to handle floating-point precision |
| 311 | + np.testing.assert_allclose(G.W.toarray(), G2.W.toarray(), rtol=1e-7, atol=1e-10) |
310 | 312 | finally: |
311 | 313 | gg.NUMBA_AVAILABLE = original_numba |
312 | 314 |
|
|
0 commit comments