|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | # |
4 | 4 |
|
|
7 | 7 | from pylibraft.common import device_ndarray |
8 | 8 |
|
9 | 9 | from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq |
10 | | -from cuvs.tests.ann_utils import generate_data |
| 10 | +from cuvs.tests.ann_utils import calc_recall, generate_data |
11 | 11 |
|
12 | 12 |
|
13 | 13 | @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte]) |
@@ -77,5 +77,17 @@ def run_save_load(ann_module, dtype): |
77 | 77 | neighbors2 = neighbors_dev.copy_to_host() |
78 | 78 | dist2 = distance_dev.copy_to_host() |
79 | 79 |
|
80 | | - assert np.all(neighbors == neighbors2) |
81 | 80 | assert np.allclose(dist, dist2, rtol=1e-6) |
| 81 | + |
| 82 | + # Sort the neighbors to avoid ordering issues |
| 83 | + sorted_neighbors = np.argsort(neighbors, axis=-1) |
| 84 | + sorted_neighbors2 = np.argsort(neighbors2, axis=-1) |
| 85 | + neighbors = np.take_along_axis(neighbors, sorted_neighbors, axis=-1) |
| 86 | + neighbors2 = np.take_along_axis(neighbors2, sorted_neighbors2, axis=-1) |
| 87 | + all_match = np.all(neighbors == neighbors2) |
| 88 | + # If the neighbors are not the same, there might be a cutoff between the k |
| 89 | + # and k+1 neighbors at the same distance. |
| 90 | + # Calculate that the recall is at least 99.8% |
| 91 | + if not all_match: |
| 92 | + recall = calc_recall(neighbors, neighbors2) |
| 93 | + assert recall >= 0.998 |
0 commit comments