Skip to content

Commit d0e0317

Browse files
authored
Update temp_size metadata to int64 (#289)
1 parent 3ddb79d commit d0e0317

File tree

7 files changed

+183
-354
lines changed

7 files changed

+183
-354
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ apis/python/**/*.egg-info/
6161
apis/python/.installed.cfg
6262
apis/python/*.egg
6363
apis/python/MANIFEST
64+
src/config.h
6465

6566
# PyInstaller
6667
# Usually these files are written by a python script from a template

apis/python/test/test_type_erased_module.py

Lines changed: 1 addition & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -228,20 +228,6 @@ def test_construct_IndexIVFFlat():
228228
assert a.id_type_string() == "int64"
229229
assert a.px_type_string() == "uint64"
230230

231-
# TODO: Create some indexes with other than uint64 id_type and test them
232-
a = vspy.IndexIVFFlat(ctx, siftsmall_group_uri)
233-
assert a.feature_type_string() == "float32"
234-
assert a.id_type_string() == "uint64"
235-
assert a.px_type_string() == "uint64"
236-
assert a.dimension() == 128
237-
238-
a = vspy.IndexIVFFlat(ctx, bigann10k_group_uri)
239-
assert a.feature_type_string() == "uint8"
240-
assert a.id_type_string() == "uint64"
241-
assert a.px_type_string() == "uint64"
242-
assert a.dimension() == 128
243-
244-
245231
def test_inplace_build_infinite_query_IndexIVFFlat():
246232
k_nn = 10
247233
nprobe = 32
@@ -272,108 +258,4 @@ def test_inplace_build_infinite_query_IndexIVFFlat():
272258
if nprobe == 8:
273259
assert recall > 0.925
274260
if nprobe == 32:
275-
assert recall >= 0.999
276-
277-
278-
def test_read_index_and_infinite_query():
279-
k_nn = 10
280-
nprobe = 32
281-
282-
for nprobe in [8, 32]:
283-
a = vspy.IndexIVFFlat(ctx, siftsmall_group_uri)
284-
assert a.feature_type_string() == "float32"
285-
assert a.id_type_string() == "uint64"
286-
assert a.px_type_string() == "uint64"
287-
assert a.dimension() == 128
288-
289-
query_set = vspy.FeatureVectorArray(ctx, siftsmall_query_uri)
290-
assert query_set.feature_type_string() == "float32"
291-
292-
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
293-
assert groundtruth_set.feature_type_string() == "uint64"
294-
295-
s, t = a.query_infinite_ram(query_set, k_nn, nprobe)
296-
297-
intersections = vspy.count_intersections(t, groundtruth_set, k_nn)
298-
299-
nt = np.double(t.num_vectors()) * np.double(k_nn)
300-
recall = intersections / nt
301-
302-
logging.info(f"nprobe = {nprobe}, recall={recall}")
303-
304-
if nprobe == 8:
305-
assert recall > 0.925
306-
if nprobe == 32:
307-
assert recall == 1.0
308-
309-
310-
def test_read_index_and_finite_query_default_upper_bound():
311-
k_nn = 10
312-
nprobe = 32
313-
314-
for nprobe in [8, 32]:
315-
a = vspy.IndexIVFFlat(ctx, siftsmall_group_uri)
316-
query_set = vspy.FeatureVectorArray(ctx, siftsmall_query_uri)
317-
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
318-
319-
s, t = a.query_finite_ram(query_set, k_nn, nprobe)
320-
321-
intersections = vspy.count_intersections(t, groundtruth_set, k_nn)
322-
323-
nt = np.double(t.num_vectors()) * np.double(k_nn)
324-
recall = intersections / nt
325-
326-
logging.info(f"nprobe = {nprobe}, recall={recall}")
327-
328-
if nprobe == 8:
329-
assert recall > 0.925
330-
if nprobe == 32:
331-
assert recall == 1.0
332-
333-
334-
def test_read_index_and_finite_query_0_upper_bound():
335-
k_nn = 10
336-
nprobe = 32
337-
338-
for nprobe in [8, 32]:
339-
a = vspy.IndexIVFFlat(ctx, siftsmall_group_uri)
340-
query_set = vspy.FeatureVectorArray(ctx, siftsmall_query_uri)
341-
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
342-
343-
s, t = a.query_finite_ram(query_set, k_nn, nprobe, 0)
344-
345-
intersections = vspy.count_intersections(t, groundtruth_set, k_nn)
346-
347-
nt = np.double(t.num_vectors()) * np.double(k_nn)
348-
recall = intersections / nt
349-
350-
logging.info(f"nprobe = {nprobe}, recall={recall}")
351-
352-
if nprobe == 8:
353-
assert recall > 0.925
354-
if nprobe == 32:
355-
assert recall == 1.0
356-
357-
358-
def test_read_index_and_finite_query_1000_upper_bound():
359-
k_nn = 10
360-
nprobe = 32
361-
362-
for nprobe in [8, 32]:
363-
a = vspy.IndexIVFFlat(ctx, siftsmall_group_uri)
364-
query_set = vspy.FeatureVectorArray(ctx, siftsmall_query_uri)
365-
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
366-
367-
s, t = a.query_finite_ram(query_set, k_nn, nprobe, 1000)
368-
369-
intersections = vspy.count_intersections(t, groundtruth_set, k_nn)
370-
371-
nt = np.double(t.num_vectors()) * np.double(k_nn)
372-
recall = intersections / nt
373-
374-
logging.info(f"nprobe = {nprobe}, recall={recall}")
375-
376-
if nprobe == 8:
377-
assert recall > 0.925
378-
if nprobe == 32:
379-
assert recall == 1.0
261+
assert recall >= 0.999

src/include/index/index_metadata.h

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
* "index_type", // "FLAT", "IVF_FLAT", "Vamana"
3737
* "ingestion_timestamps", // (json) list
3838
* "storage_version", // "0.3"
39-
* "temp_size",
39+
* "temp_size", // TILEDB_INT64 or TILEDB_FLOAT64
4040
*
4141
* "feature_datatype", // TILEDB_UINT32
4242
* "id_datatype", // TILEDB_UINT32
@@ -89,7 +89,7 @@ class base_index_metadata {
8989
std::vector<base_sizes_type> base_sizes_;
9090

9191
/** Record size of temp data */
92-
uint64_t temp_size_{0};
92+
int64_t temp_size_{0};
9393

9494
uint32_t dimension_{0};
9595

@@ -127,7 +127,7 @@ class base_index_metadata {
127127
using metadata_arithmetic_check_type =
128128
std::tuple<std::string, void*, tiledb_datatype_t, bool>;
129129
std::vector<metadata_arithmetic_check_type> metadata_arithmetic_checks{
130-
{"temp_size", &temp_size_, TILEDB_UINT64, true},
130+
{"temp_size", &temp_size_, TILEDB_INT64, true},
131131
//{"index_kind",
132132
// nstatic_cast<IndexMetadata*>(this)->index_kind_,
133133
// TILEDB_UINT64,
@@ -213,14 +213,14 @@ class base_index_metadata {
213213

214214
// Handle temp_size as a special case for now
215215
if (name == "temp_size") {
216-
if (v_type == TILEDB_UINT64) {
217-
*static_cast<uint64_t*>(value) = *static_cast<const uint64_t*>(v);
216+
if (v_type == TILEDB_INT64) {
217+
*static_cast<int64_t*>(value) = *static_cast<const int64_t*>(v);
218218
} else if (v_type == TILEDB_FLOAT64) {
219-
*static_cast<uint64_t*>(value) =
220-
static_cast<uint64_t>(*static_cast<const double*>(v));
219+
*static_cast<int64_t*>(value) =
220+
static_cast<int64_t>(*static_cast<const double*>(v));
221221
} else {
222222
throw std::runtime_error(
223-
"temp_size must be a uint64_t or float64 not " +
223+
"temp_size must be a int64_t or float64 not " +
224224
tiledb::impl::type_to_str(v_type));
225225
}
226226
return;
@@ -238,6 +238,9 @@ class base_index_metadata {
238238
case TILEDB_FLOAT32:
239239
*static_cast<float*>(value) = *static_cast<const float*>(v);
240240
break;
241+
case TILEDB_INT64:
242+
*static_cast<int64_t*>(value) = *static_cast<const int64_t*>(v);
243+
break;
241244
case TILEDB_UINT64:
242245
*static_cast<uint64_t*>(value) = *static_cast<const uint64_t*>(v);
243246
break;
@@ -288,13 +291,13 @@ class base_index_metadata {
288291
throw std::runtime_error("Missing metadata: temp_size");
289292
}
290293
read_group.get_metadata("temp_size", &v_type, &v_num, &v);
291-
if (v_type == TILEDB_UINT64) {
292-
temp_size_ = *static_cast<const uint64_t*>(v);
294+
if (v_type == TILEDB_INT64) {
295+
temp_size_ = *static_cast<const int64_t*>(v);
293296
} else if (v_type == TILEDB_FLOAT64) {
294-
temp_size_ = static_cast<uint64_t>(*static_cast<const double*>(v));
297+
temp_size_ = static_cast<int64_t>(*static_cast<const double*>(v));
295298
} else {
296299
throw std::runtime_error(
297-
"temp_size must be a uint64_t or float64 not " +
300+
"temp_size must be a int64_t or float64 not " +
298301
tiledb::impl::type_to_str(v_type));
299302
}
300303

@@ -418,6 +421,11 @@ class base_index_metadata {
418421
if (*static_cast<float*>(value) != *static_cast<float*>(rhs_value)) {
419422
return false;
420423
}
424+
case TILEDB_INT64:
425+
if (*static_cast<int64_t*>(value) !=
426+
*static_cast<int64_t*>(rhs_value)) {
427+
return false;
428+
}
421429
case TILEDB_UINT64:
422430
if (*static_cast<uint64_t*>(value) !=
423431
*static_cast<uint64_t*>(rhs_value)) {
@@ -519,6 +527,10 @@ class base_index_metadata {
519527
case TILEDB_FLOAT32:
520528
std::cout << name << ": " << *static_cast<float*>(value) << std::endl;
521529
break;
530+
case TILEDB_INT64:
531+
std::cout << name << ": " << *static_cast<int64_t*>(value)
532+
<< std::endl;
533+
break;
522534
case TILEDB_UINT64:
523535
std::cout << name << ": " << *static_cast<uint64_t*>(value)
524536
<< std::endl;

src/include/test/unit_ivf_flat_group.cc

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -92,44 +92,6 @@ struct dummy_index {
9292
}
9393
};
9494

95-
TEST_CASE("ivf_flat_group: member type", "[ivf_flat_group") {
96-
tiledb::Context ctx;
97-
98-
auto x = ivf_flat_index_group(dummy_index{}, ctx, sift_group_uri);
99-
}
100-
101-
TEST_CASE("ivf_flat_group: constructor", "[ivf_flat_group]") {
102-
tiledb::Context ctx;
103-
104-
auto foo = dummy_index{};
105-
auto n = foo.dimension();
106-
std::reference_wrapper<const dummy_index> bar = foo;
107-
auto m = bar.get().dimension();
108-
109-
auto x = ivf_flat_index_group(dummy_index{}, ctx, sift_group_uri);
110-
auto y = ivf_flat_index_group(dummy_index{}, ctx, sift_group_uri);
111-
}
112-
113-
TEST_CASE("ivf_flat_group: default constructor", "[ivf_flat_group]") {
114-
tiledb::Context ctx;
115-
auto x = ivf_flat_index_group(dummy_index{}, ctx, sift_group_uri);
116-
x.dump("Default constructor");
117-
}
118-
119-
TEST_CASE("ivf_flat_group: read constructor", "[ivf_flat_group]") {
120-
tiledb::Context ctx;
121-
auto x =
122-
ivf_flat_index_group(dummy_index{}, ctx, sift_group_uri, TILEDB_READ);
123-
x.dump("Read constructor");
124-
}
125-
126-
TEST_CASE("ivf_flat_group: read constructor with version", "[ivf_flat_group]") {
127-
tiledb::Context ctx;
128-
auto x = ivf_flat_index_group(
129-
dummy_index{}, ctx, sift_group_uri, TILEDB_READ, 0, "0.3");
130-
x.dump("Read constructor with version");
131-
}
132-
13395
// The catch2 check for exception doesn't seem to be working correctly
13496
// @todo Fix this
13597
#if 0

0 commit comments

Comments
 (0)