Skip to content

Commit f57684b

Browse files
authored
Add more Python and C++ unit tests around metadata (#281)
1 parent c307438 commit f57684b

File tree

6 files changed

+134
-23
lines changed

6 files changed

+134
-23
lines changed

apis/python/test/test_index.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import pytest
33
from array_paths import *
44
from common import *
5+
import json
56

67
import tiledb.vector_search.index as ind
78
from tiledb.vector_search import Index
89
from tiledb.vector_search import flat_index
910
from tiledb.vector_search import ivf_flat_index
11+
from tiledb.vector_search.index import create_metadata
12+
from tiledb.vector_search.index import DATASET_TYPE
1013
from tiledb.vector_search.flat_index import FlatIndex
1114
from tiledb.vector_search.ingestion import ingest
1215
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
@@ -18,18 +21,49 @@ def query_and_check(index, queries, k, expected, **kwargs):
1821
result_d, result_i = index.query(queries, k=k, **kwargs)
1922
assert expected.issubset(set(result_i[0]))
2023

24+
def check_default_metadata(uri, expected_vector_type, expected_storage_version, expected_index_type):
25+
group = tiledb.Group(uri, "r", ctx=tiledb.Ctx(None))
26+
assert "dataset_type" in group.meta
27+
assert group.meta["dataset_type"] == DATASET_TYPE
28+
assert type(group.meta["dataset_type"]) == str
29+
30+
assert "dtype" in group.meta
31+
assert group.meta["dtype"] == np.dtype(expected_vector_type).name
32+
assert type(group.meta["dtype"]) == str
33+
34+
assert "storage_version" in group.meta
35+
assert group.meta["storage_version"] == expected_storage_version
36+
assert type(group.meta["storage_version"]) == str
37+
38+
assert "index_type" in group.meta
39+
assert group.meta["index_type"] == expected_index_type
40+
assert type(group.meta["index_type"]) == str
41+
42+
assert "base_sizes" in group.meta
43+
assert group.meta["base_sizes"] == json.dumps([0])
44+
assert type(group.meta["base_sizes"]) == str
45+
46+
assert "ingestion_timestamps" in group.meta
47+
assert group.meta["ingestion_timestamps"] == json.dumps([0])
48+
assert type(group.meta["ingestion_timestamps"]) == str
49+
50+
assert "has_updates" in group.meta
51+
assert group.meta["has_updates"] == False
52+
assert type(group.meta["has_updates"]) == np.int64
2153

2254
def test_flat_index(tmp_path):
2355
uri = os.path.join(tmp_path, "array")
24-
index = flat_index.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))
56+
vector_type = np.dtype(np.uint8)
57+
index = flat_index.create(uri=uri, dimensions=3, vector_type=vector_type)
2558
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
59+
check_default_metadata(uri, vector_type, STORAGE_VERSION, "FLAT")
2660

2761
update_vectors = np.empty([5], dtype=object)
28-
update_vectors[0] = np.array([0, 0, 0], dtype=np.dtype(np.uint8))
29-
update_vectors[1] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
30-
update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.uint8))
31-
update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
32-
update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.uint8))
62+
update_vectors[0] = np.array([0, 0, 0], dtype=vector_type)
63+
update_vectors[1] = np.array([1, 1, 1], dtype=vector_type)
64+
update_vectors[2] = np.array([2, 2, 2], dtype=vector_type)
65+
update_vectors[3] = np.array([3, 3, 3], dtype=vector_type)
66+
update_vectors[4] = np.array([4, 4, 4], dtype=vector_type)
3367
index.update_batch(vectors=update_vectors, external_ids=np.array([0, 1, 2, 3, 4]))
3468
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3})
3569

@@ -43,8 +77,8 @@ def test_flat_index(tmp_path):
4377
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {0, 2, 4})
4478

4579
update_vectors = np.empty([2], dtype=object)
46-
update_vectors[0] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
47-
update_vectors[1] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
80+
update_vectors[0] = np.array([1, 1, 1], dtype=vector_type)
81+
update_vectors[1] = np.array([3, 3, 3], dtype=vector_type)
4882
index.update_batch(vectors=update_vectors, external_ids=np.array([1, 3]))
4983
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3})
5084

@@ -61,9 +95,9 @@ def test_flat_index(tmp_path):
6195
def test_ivf_flat_index(tmp_path):
6296
partitions = 10
6397
uri = os.path.join(tmp_path, "array")
64-
98+
vector_type = np.dtype(np.uint8)
6599
index = ivf_flat_index.create(
66-
uri=uri, dimensions=3, vector_type=np.dtype(np.uint8), partitions=partitions
100+
uri=uri, dimensions=3, vector_type=vector_type, partitions=partitions
67101
)
68102
query_and_check(
69103
index,
@@ -72,13 +106,14 @@ def test_ivf_flat_index(tmp_path):
72106
{ind.MAX_UINT64},
73107
nprobe=partitions,
74108
)
109+
check_default_metadata(uri, vector_type, STORAGE_VERSION, "IVF_FLAT")
75110

76111
update_vectors = np.empty([5], dtype=object)
77-
update_vectors[0] = np.array([0, 0, 0], dtype=np.dtype(np.uint8))
78-
update_vectors[1] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
79-
update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.uint8))
80-
update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
81-
update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.uint8))
112+
update_vectors[0] = np.array([0, 0, 0], dtype=vector_type)
113+
update_vectors[1] = np.array([1, 1, 1], dtype=vector_type)
114+
update_vectors[2] = np.array([2, 2, 2], dtype=vector_type)
115+
update_vectors[3] = np.array([3, 3, 3], dtype=vector_type)
116+
update_vectors[4] = np.array([4, 4, 4], dtype=vector_type)
82117
index.update_batch(vectors=update_vectors, external_ids=np.array([0, 1, 2, 3, 4]))
83118

84119
query_and_check(
@@ -102,8 +137,8 @@ def test_ivf_flat_index(tmp_path):
102137
)
103138

104139
update_vectors = np.empty([2], dtype=object)
105-
update_vectors[0] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
106-
update_vectors[1] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
140+
update_vectors[0] = np.array([1, 1, 1], dtype=vector_type)
141+
update_vectors[1] = np.array([3, 3, 3], dtype=vector_type)
107142
index.update_batch(vectors=update_vectors, external_ids=np.array([1, 3]))
108143
query_and_check(
109144
index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3}, nprobe=partitions
@@ -251,3 +286,17 @@ def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_pa
251286
# TODO: This also throws a TypeError for incorrect dimension
252287
with pytest.raises(TypeError):
253288
index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
289+
290+
def test_create_metadata(tmp_path):
291+
uri = os.path.join(tmp_path, "array")
292+
293+
# Create the metadata at the specified URI.
294+
dimensions = 3
295+
vector_type: np.dtype = np.dtype(np.uint8)
296+
index_type: str = "IVF_FLAT"
297+
storage_version: str = STORAGE_VERSION
298+
group_exists: bool = False
299+
create_metadata(uri, dimensions, vector_type, index_type, storage_version, group_exists)
300+
301+
# Check it contains the default metadata.
302+
check_default_metadata(uri, vector_type, storage_version, index_type)

src/include/index/vamana_index.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
#include <queue>
4040
#include <unordered_set>
4141

42+
#include "detail/graph/adj_list.h"
43+
#include "detail/graph/graph_utils.h"
44+
#include "detail/linalg/vector.h"
45+
#include "index/vamana_group.h"
4246
#include "scoring.h"
4347
#include "stats.h"
4448
#include "utils/fixed_min_heap.h"
4549
#include "utils/print_types.h"
4650

47-
#include "detail/graph/adj_list.h"
48-
#include "detail/graph/graph_utils.h"
49-
#include "index/vamana_group.h"
50-
5151
#include <tiledb/tiledb>
5252

5353
#include <tiledb/group_experimental.h>

src/include/test/unit_api_ivf_flat_index.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ TEST_CASE(
201201
auto nt = num_vectors(t);
202202
auto recall = ((double)intersections) / ((double)nt * k_nn);
203203
if (nprobe == 32) {
204-
CHECK(recall == 1.0);
204+
CHECK(std::abs(recall - 1.0) <= 1e-3);
205205
} else if (nprobe == 8) {
206206
CHECK(recall > 0.925);
207207
}
@@ -305,7 +305,7 @@ TEST_CASE(
305305
CHECK(nt == nv);
306306
auto recall = ((double)intersections_a) / ((double)nt * k_nn);
307307
if (nprobe == 32) {
308-
CHECK(recall == 1.0);
308+
CHECK(std::abs(recall - 1.0) <= 1e-3);
309309
} else if (nprobe == 8) {
310310
CHECK(recall > 0.925);
311311
}

src/include/test/unit_backwards_compatibility.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ TEST_CASE(
5353
"backwards_compatibility: test_query_old_indices",
5454
"[backwards_compatibility]") {
5555
tiledb::Context ctx;
56+
tiledb::Config cfg;
57+
5658
std::string datasets_path = backwards_compatibility_root / "data";
5759
auto base =
5860
read_bin_local<siftsmall_feature_type>(ctx, siftmicro_inputs_file);
@@ -93,8 +95,11 @@ TEST_CASE(
9395
index_uri.find("0.0.17") != std::string::npos) {
9496
continue;
9597
}
98+
99+
auto read_group = tiledb::Group(ctx, index_uri, TILEDB_READ, cfg);
96100
std::vector<float> expected_distances(query_indices.size(), 0.0);
97101
if (index_uri.find("ivf_flat") != std::string::npos) {
102+
// First check that we can query the index.
98103
auto index = IndexIVFFlat(ctx, index_uri);
99104
auto&& [scores, ids] =
100105
index.query_infinite_ram(queries_feature_vector_array, 1, 10);
@@ -111,6 +116,10 @@ TEST_CASE(
111116
CHECK(ids_span[0][i] == query_indices[i]);
112117
CHECK(scores_span[0][i] == 0);
113118
}
119+
120+
// Next check that we can load the metadata.
121+
auto metadata = ivf_flat_index_metadata();
122+
metadata.load_metadata(read_group);
114123
} else if (index_uri.find("flat") != std::string::npos) {
115124
// TODO(paris): Fix flat_l2_index and re-enable. Right now it just tries
116125
// to load the URI as a tdbMatrix.

src/include/test/unit_ivf_flat_metadata.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include <string>
3636

3737
#include "array_defs.h"
38+
#include "detail/linalg/tdb_matrix.h"
39+
#include "index/ivf_flat_index.h"
3840
#include "index/ivf_flat_metadata.h"
3941

4042
TEST_CASE("ivf_flat_metadata: test test", "[ivf_flat_metadata]") {
@@ -55,6 +57,31 @@ TEST_CASE(
5557
y.dump();
5658
}
5759

60+
TEST_CASE(
61+
"ivf_flat_metadata: load metadata from index", "[ivf_flat_metadata]") {
62+
tiledb::Context ctx;
63+
tiledb::Config cfg;
64+
65+
std::string uri =
66+
(std::filesystem::temp_directory_path() / "tmp_ivf_index").string();
67+
auto training_vectors =
68+
tdbColMajorPreLoadMatrix<float>(ctx, siftsmall_inputs_uri, 0);
69+
auto idx = ivf_flat_index<float, uint32_t, uint32_t>(100, 1);
70+
idx.train(training_vectors, kmeans_init::kmeanspp);
71+
idx.add(training_vectors);
72+
idx.write_index(ctx, uri, true);
73+
74+
auto read_group = tiledb::Group(ctx, uri, TILEDB_READ, cfg);
75+
76+
auto x = ivf_flat_index_metadata();
77+
x.load_metadata(read_group);
78+
79+
// Compare two constructed objects.
80+
ivf_flat_index_metadata y;
81+
y.load_metadata(read_group);
82+
CHECK(x.compare_metadata(y));
83+
}
84+
5885
TEST_CASE("ivf_flat_metadata: open group", "[ivf_flat_metadata]") {
5986
tiledb::Context ctx;
6087
tiledb::Config cfg;

src/include/test/unit_vamana_metadata.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
#include <tiledb/tiledb>
3434
#include <vector>
3535
#include "array_defs.h"
36+
#include "detail/linalg/tdb_matrix.h"
37+
#include "index/vamana_index.h"
3638
#include "index/vamana_metadata.h"
3739

3840
std::vector<std::tuple<std::string, std::string>> expected_str{
@@ -88,6 +90,30 @@ TEST_CASE("vamana_metadata: default constructor dump", "[vamana_metadata]") {
8890
}
8991
}
9092

93+
TEST_CASE("vamana_metadata: load metadata from index", "[vamana_metadata]") {
94+
tiledb::Context ctx;
95+
tiledb::Config cfg;
96+
97+
std::string uri =
98+
(std::filesystem::temp_directory_path() / "tmp_vamana_index").string();
99+
auto training_vectors =
100+
tdbColMajorPreLoadMatrix<float>(ctx, siftsmall_inputs_uri);
101+
auto idx =
102+
vamana_index<float, uint64_t>(num_vectors(training_vectors), 20, 40, 30);
103+
idx.train(training_vectors);
104+
idx.add(training_vectors);
105+
idx.write_index(ctx, uri, true);
106+
107+
auto read_group = tiledb::Group(ctx, uri, TILEDB_READ, cfg);
108+
auto x = vamana_index_metadata();
109+
x.load_metadata(read_group);
110+
111+
// Compare two constructed objects.
112+
vamana_index_metadata y;
113+
y.load_metadata(read_group);
114+
CHECK(x.compare_metadata(y));
115+
}
116+
91117
// @todo More vamana groups (from "real" data) to test with
92118
TEST_CASE("vamana_metadata: open group", "[vamana_metadata]") {
93119
bool debug = false;

0 commit comments

Comments
 (0)