|
7 | 7 | from common import * |
8 | 8 | from common import load_metadata |
9 | 9 |
|
10 | | -import tiledb.vector_search.index as ind |
11 | 10 | from tiledb.vector_search import Index |
12 | 11 | from tiledb.vector_search import flat_index |
13 | 12 | from tiledb.vector_search import ivf_flat_index |
|
17 | 16 | from tiledb.vector_search.index import create_metadata |
18 | 17 | from tiledb.vector_search.ingestion import ingest |
19 | 18 | from tiledb.vector_search.ivf_flat_index import IVFFlatIndex |
| 19 | +from tiledb.vector_search.utils import MAX_FLOAT32 |
| 20 | +from tiledb.vector_search.utils import MAX_UINT64 |
20 | 21 | from tiledb.vector_search.utils import is_type_erased_index |
21 | 22 | from tiledb.vector_search.utils import load_fvecs |
22 | 23 | from tiledb.vector_search.vamana_index import VamanaIndex |
@@ -78,7 +79,7 @@ def test_flat_index(tmp_path): |
78 | 79 | uri = os.path.join(tmp_path, "array") |
79 | 80 | vector_type = np.dtype(np.uint8) |
80 | 81 | index = flat_index.create(uri=uri, dimensions=3, vector_type=vector_type) |
81 | | - query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64}) |
| 82 | + query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64}) |
82 | 83 | check_default_metadata(uri, vector_type, STORAGE_VERSION, "FLAT") |
83 | 84 |
|
84 | 85 | update_vectors = np.empty([5], dtype=object) |
@@ -136,7 +137,7 @@ def test_ivf_flat_index(tmp_path): |
136 | 137 | index, |
137 | 138 | np.array([[2, 2, 2]], dtype=np.float32), |
138 | 139 | 3, |
139 | | - {ind.MAX_UINT64}, |
| 140 | + {MAX_UINT64}, |
140 | 141 | nprobe=partitions, |
141 | 142 | ) |
142 | 143 | check_default_metadata(uri, vector_type, STORAGE_VERSION, "IVF_FLAT") |
@@ -221,12 +222,12 @@ def test_vamana_index_simple(tmp_path): |
221 | 222 | # Create the index. |
222 | 223 | index = vamana_index.create(uri=uri, dimensions=dimensions, vector_type=vector_type) |
223 | 224 | assert index.get_dimensions() == dimensions |
224 | | - query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64}) |
| 225 | + query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64}) |
225 | 226 |
|
226 | 227 | # Open the index. |
227 | 228 | index = VamanaIndex(uri=uri) |
228 | 229 | assert index.get_dimensions() == dimensions |
229 | | - query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64}) |
| 230 | + query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64}) |
230 | 231 |
|
231 | 232 | vfs = tiledb.VFS() |
232 | 233 | assert vfs.dir_size(uri) > 0 |
@@ -254,11 +255,9 @@ def test_vamana_index(tmp_path): |
254 | 255 | distances, ids = index.query(queries, k=1) |
255 | 256 | assert distances.shape == (1, 1) |
256 | 257 | assert ids.shape == (1, 1) |
257 | | - assert distances[0][0] == ind.MAX_FLOAT_32 |
258 | | - assert ids[0][0] == ind.MAX_UINT64 |
259 | | - query_and_check_distances( |
260 | | - index, queries, 1, [[ind.MAX_FLOAT_32]], [[ind.MAX_UINT64]] |
261 | | - ) |
| 258 | + assert distances[0][0] == MAX_FLOAT32 |
| 259 | + assert ids[0][0] == MAX_UINT64 |
| 260 | + query_and_check_distances(index, queries, 1, [[MAX_FLOAT32]], [[MAX_UINT64]]) |
262 | 261 | check_default_metadata(uri, vector_type, STORAGE_VERSION, "VAMANA") |
263 | 262 |
|
264 | 263 | update_vectors = np.empty([5], dtype=object) |
|
0 commit comments