Skip to content

Commit f564a8d

Browse files
authored
Share max int / float values and index lists in Python (#396)
1 parent bcf7220 commit f564a8d

File tree

10 files changed

+54
-38
lines changed

10 files changed

+54
-38
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,11 @@
1313
from tiledb.vector_search.storage_formats import STORAGE_VERSION
1414
from tiledb.vector_search.storage_formats import storage_formats
1515
from tiledb.vector_search.storage_formats import validate_storage_version
16+
from tiledb.vector_search.utils import MAX_FLOAT32
17+
from tiledb.vector_search.utils import MAX_INT32
18+
from tiledb.vector_search.utils import MAX_UINT64
1619
from tiledb.vector_search.utils import add_to_group
1720

18-
MAX_INT32 = np.iinfo(np.dtype("int32")).max
19-
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
2021
TILE_SIZE_BYTES = 128000000 # 128MB
2122
INDEX_TYPE = "FLAT"
2223

@@ -116,8 +117,8 @@ def query_internal(
116117
# - typecheck queries
117118
# - add all the options and query strategies
118119
if self.size == 0:
119-
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
120-
(queries.shape[0], k), index.MAX_UINT64
120+
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
121+
(queries.shape[0], k), MAX_UINT64
121122
)
122123

123124
assert queries.dtype == np.float32

apis/python/src/tiledb/vector_search/index.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,10 @@
77
from tiledb.vector_search import _tiledbvspy as vspy
88
from tiledb.vector_search.module import *
99
from tiledb.vector_search.storage_formats import storage_formats
10+
from tiledb.vector_search.utils import MAX_FLOAT32
11+
from tiledb.vector_search.utils import MAX_UINT64
1012
from tiledb.vector_search.utils import add_to_group
1113

12-
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
13-
MAX_INT32 = np.iinfo(np.dtype("int32")).max
14-
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
1514
DATASET_TYPE = "vector_search"
1615

1716

@@ -190,7 +189,7 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
190189
if self.query_base_array:
191190
return self.query_internal(queries, k, **kwargs)
192191
else:
193-
return np.full((queries.shape[0], k), MAX_FLOAT_32), np.full(
192+
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
194193
(queries.shape[0], k), MAX_UINT64
195194
)
196195

@@ -213,7 +212,7 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
213212
queries, retrieval_k, **kwargs
214213
)
215214
else:
216-
internal_results_d = np.full((queries.shape[0], k), MAX_FLOAT_32)
215+
internal_results_d = np.full((queries.shape[0], k), MAX_FLOAT32)
217216
internal_results_i = np.full((queries.shape[0], k), MAX_UINT64)
218217
addition_results_d, addition_results_i, updated_ids = future.result()
219218

@@ -223,7 +222,7 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
223222
res_id = 0
224223
for res in query:
225224
if res in updated_ids:
226-
internal_results_d[query_id, res_id] = MAX_FLOAT_32
225+
internal_results_d[query_id, res_id] = MAX_FLOAT32
227226
internal_results_i[query_id, res_id] = MAX_UINT64
228227
res_id += 1
229228
query_id += 1
@@ -243,7 +242,7 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
243242
addition_results_d[query_id, res_id] == 0
244243
and addition_results_i[query_id, res_id] == 0
245244
):
246-
addition_results_d[query_id, res_id] = MAX_FLOAT_32
245+
addition_results_d[query_id, res_id] = MAX_FLOAT32
247246
addition_results_i[query_id, res_id] = MAX_UINT64
248247
res_id += 1
249248
query_id += 1

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@
3535
from tiledb.vector_search.storage_formats import STORAGE_VERSION
3636
from tiledb.vector_search.storage_formats import storage_formats
3737
from tiledb.vector_search.storage_formats import validate_storage_version
38+
from tiledb.vector_search.utils import MAX_FLOAT32
39+
from tiledb.vector_search.utils import MAX_INT32
40+
from tiledb.vector_search.utils import MAX_UINT64
3841
from tiledb.vector_search.utils import add_to_group
3942

40-
MAX_INT32 = np.iinfo(np.dtype("int32")).max
41-
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
4243
TILE_SIZE_BYTES = 64000000 # 64MB
4344
INDEX_TYPE = "IVF_FLAT"
4445

@@ -215,8 +216,8 @@ def query_internal(
215216
If provided, this is the number of workers to use for the query execution.
216217
"""
217218
if self.size == 0:
218-
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
219-
(queries.shape[0], k), index.MAX_UINT64
219+
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
220+
(queries.shape[0], k), MAX_UINT64
220221
)
221222

222223
if mode != Mode.BATCH and resources:

apis/python/src/tiledb/vector_search/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import tiledb
88
from tiledb.vector_search import _tiledbvspy as vspy
99

10+
MAX_INT32 = np.iinfo(np.dtype("int32")).max
11+
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
12+
MAX_FLOAT32 = np.finfo(np.dtype("float32")).max
13+
1014

1115
def is_type_erased_index(index_type: str) -> bool:
1216
return index_type == "VAMANA"

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from tiledb.vector_search.storage_formats import STORAGE_VERSION
2222
from tiledb.vector_search.storage_formats import storage_formats
2323
from tiledb.vector_search.storage_formats import validate_storage_version
24+
from tiledb.vector_search.utils import MAX_FLOAT32
25+
from tiledb.vector_search.utils import MAX_UINT64
2426
from tiledb.vector_search.utils import to_temporal_policy
2527

26-
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
2728
INDEX_TYPE = "VAMANA"
2829

2930

@@ -97,8 +98,8 @@ def query_internal(
9798
"""
9899
warnings.warn("The Vamana index is not yet supported, please use with caution.")
99100
if self.size == 0:
100-
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
101-
(queries.shape[0], k), index.MAX_UINT64
101+
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
102+
(queries.shape[0], k), MAX_UINT64
102103
)
103104

104105
assert queries.dtype == np.float32

apis/python/test/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,19 @@
88

99
import tiledb
1010
from tiledb.cloud import groups
11+
from tiledb.vector_search.flat_index import FlatIndex
12+
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
1113
from tiledb.vector_search.storage_formats import STORAGE_VERSION
1214
from tiledb.vector_search.storage_formats import storage_formats
15+
from tiledb.vector_search.vamana_index import VamanaIndex
16+
17+
INDEXES = ["FLAT", "IVF_FLAT", "VAMANA"]
18+
INDEX_CLASSES = [FlatIndex, IVFFlatIndex, VamanaIndex]
19+
INDEX_FILES = [
20+
tiledb.vector_search.flat_index,
21+
tiledb.vector_search.ivf_flat_index,
22+
tiledb.vector_search.vamana_index,
23+
]
1324

1425

1526
def xbin_mmap(fname, dtype):

apis/python/test/test_index.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from common import *
88
from common import load_metadata
99

10-
import tiledb.vector_search.index as ind
1110
from tiledb.vector_search import Index
1211
from tiledb.vector_search import flat_index
1312
from tiledb.vector_search import ivf_flat_index
@@ -17,6 +16,8 @@
1716
from tiledb.vector_search.index import create_metadata
1817
from tiledb.vector_search.ingestion import ingest
1918
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
19+
from tiledb.vector_search.utils import MAX_FLOAT32
20+
from tiledb.vector_search.utils import MAX_UINT64
2021
from tiledb.vector_search.utils import is_type_erased_index
2122
from tiledb.vector_search.utils import load_fvecs
2223
from tiledb.vector_search.vamana_index import VamanaIndex
@@ -78,7 +79,7 @@ def test_flat_index(tmp_path):
7879
uri = os.path.join(tmp_path, "array")
7980
vector_type = np.dtype(np.uint8)
8081
index = flat_index.create(uri=uri, dimensions=3, vector_type=vector_type)
81-
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
82+
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64})
8283
check_default_metadata(uri, vector_type, STORAGE_VERSION, "FLAT")
8384

8485
update_vectors = np.empty([5], dtype=object)
@@ -136,7 +137,7 @@ def test_ivf_flat_index(tmp_path):
136137
index,
137138
np.array([[2, 2, 2]], dtype=np.float32),
138139
3,
139-
{ind.MAX_UINT64},
140+
{MAX_UINT64},
140141
nprobe=partitions,
141142
)
142143
check_default_metadata(uri, vector_type, STORAGE_VERSION, "IVF_FLAT")
@@ -221,12 +222,12 @@ def test_vamana_index_simple(tmp_path):
221222
# Create the index.
222223
index = vamana_index.create(uri=uri, dimensions=dimensions, vector_type=vector_type)
223224
assert index.get_dimensions() == dimensions
224-
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
225+
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64})
225226

226227
# Open the index.
227228
index = VamanaIndex(uri=uri)
228229
assert index.get_dimensions() == dimensions
229-
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
230+
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {MAX_UINT64})
230231

231232
vfs = tiledb.VFS()
232233
assert vfs.dir_size(uri) > 0
@@ -254,11 +255,9 @@ def test_vamana_index(tmp_path):
254255
distances, ids = index.query(queries, k=1)
255256
assert distances.shape == (1, 1)
256257
assert ids.shape == (1, 1)
257-
assert distances[0][0] == ind.MAX_FLOAT_32
258-
assert ids[0][0] == ind.MAX_UINT64
259-
query_and_check_distances(
260-
index, queries, 1, [[ind.MAX_FLOAT_32]], [[ind.MAX_UINT64]]
261-
)
258+
assert distances[0][0] == MAX_FLOAT32
259+
assert ids[0][0] == MAX_UINT64
260+
query_and_check_distances(index, queries, 1, [[MAX_FLOAT32]], [[MAX_UINT64]])
262261
check_default_metadata(uri, vector_type, STORAGE_VERSION, "VAMANA")
263262

264263
update_vectors = np.empty([5], dtype=object)

apis/python/test/test_ingestion.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,13 @@
1515
from tiledb.vector_search.module import array_to_matrix
1616
from tiledb.vector_search.module import kmeans_fit
1717
from tiledb.vector_search.module import kmeans_predict
18+
from tiledb.vector_search.utils import MAX_UINT64
1819
from tiledb.vector_search.utils import is_type_erased_index
1920
from tiledb.vector_search.utils import load_fvecs
2021
from tiledb.vector_search.utils import metadata_to_list
2122
from tiledb.vector_search.vamana_index import VamanaIndex
2223

2324
MINIMUM_ACCURACY = 0.85
24-
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
25-
26-
INDEXES = ["FLAT", "IVF_FLAT", "VAMANA"]
27-
INDEX_CLASSES = [FlatIndex, IVFFlatIndex, VamanaIndex]
28-
INDEX_FILES = [
29-
tiledb.vector_search.flat_index,
30-
tiledb.vector_search.ivf_flat_index,
31-
tiledb.vector_search.vamana_index,
32-
]
3325

3426

3527
def query_and_check_equals(index, queries, expected_result_d, expected_result_i):

apis/python/test/test_object_index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from tiledb.vector_search.object_readers import ObjectReader
1212

1313
EMBED_DIM = 4
14-
INDEXES = ["FLAT", "IVF_FLAT", "VAMANA"]
1514

1615

1716
# TestEmbedding with vectors of EMBED_DIM size with all values being the id of the object
@@ -236,6 +235,8 @@ def df_filter(row):
236235

237236

238237
def test_object_index(tmp_path):
238+
from common import INDEXES
239+
239240
for index_type in INDEXES:
240241
index_uri = os.path.join(tmp_path, f"object_index_{index_type}")
241242
reader = TestReader(

src/include/test/unit_ivf_pq_index.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,12 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
288288
size_t nthreads = 1;
289289

290290
tiledb::Context ctx;
291+
tiledb::VFS vfs(ctx);
291292
std::string ivf_index_uri =
292293
(std::filesystem::temp_directory_path() / "tmp_ivf_index").string();
294+
if (vfs.is_dir(ivf_index_uri)) {
295+
vfs.remove_dir(ivf_index_uri);
296+
}
293297
auto training_set = tdbColMajorMatrix<float>(ctx, siftsmall_inputs_uri, 0);
294298
load(training_set);
295299

@@ -302,6 +306,9 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
302306
ivf_index_uri =
303307
(std::filesystem::temp_directory_path() / "second_tmp_ivf_index")
304308
.string();
309+
if (vfs.is_dir(ivf_index_uri)) {
310+
vfs.remove_dir(ivf_index_uri);
311+
}
305312
idx.write_index(ctx, ivf_index_uri);
306313
auto idx2 = ivf_pq_index<float, uint32_t, uint32_t>(ctx, ivf_index_uri);
307314
idx2.read_index_infinite();

0 commit comments

Comments
 (0)