Skip to content

Commit f4f9656

Browse files
author
Nikos Papailiou
committed
Fix tests
1 parent 817dc54 commit f4f9656

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

apis/python/test/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,11 @@ def create_array(path: str, data):
181181
tiledb.Array.create(path, schema)
182182
with tiledb.open(path, "w") as A:
183183
A[:] = data
184+
185+
def accuracy(result, gt):
186+
found = 0
187+
total = 0
188+
for i in range(len(result)):
189+
total+=len(result[i])
190+
found+=len(np.intersect1d(result[i], gt[i]))
191+
return found/total

apis/python/test/test_ingestion.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_flat_ingestion_u8(tmp_path):
2222
source_type=source_type,
2323
)
2424
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
25-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
25+
assert accuracy(result, gt_i) > 0.98
2626

2727

2828
def test_flat_ingestion_f32(tmp_path):
@@ -43,7 +43,7 @@ def test_flat_ingestion_f32(tmp_path):
4343
source_type=source_type,
4444
)
4545
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
46-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
46+
assert accuracy(result, gt_i) > 0.98
4747

4848

4949
def test_ivf_flat_ingestion_u8(tmp_path):
@@ -52,13 +52,12 @@ def test_ivf_flat_ingestion_u8(tmp_path):
5252
k = 10
5353
size = 100000
5454
partitions = 100
55-
create_random_dataset_u8(nb=size, d=100, nq=2, k=k, path=dataset_dir)
55+
create_random_dataset_u8(nb=size, d=100, nq=10, k=k, path=dataset_dir)
5656
source_type = "U8BIN"
5757
dtype = np.uint8
5858

5959
query_vectors = get_queries(dataset_dir, dtype=dtype)
6060
gt_i, gt_d = get_groundtruth(dataset_dir, k)
61-
6261
index = ingest(
6362
index_type="IVF_FLAT",
6463
array_uri=array_uri,
@@ -68,15 +67,15 @@ def test_ivf_flat_ingestion_u8(tmp_path):
6867
input_vectors_per_work_item=int(size / 10),
6968
)
7069
result = np.transpose(
71-
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
70+
index.query(np.transpose(query_vectors), k=k, nprobe=10)
7271
)
73-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
72+
assert accuracy(result, gt_i) > 0.98
7473

7574
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
7675
result = np.transpose(
7776
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
7877
)
79-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
78+
assert accuracy(result, gt_i) > 0.98
8079
result = np.transpose(
8180
index_ram.query(
8281
np.transpose(query_vectors),
@@ -85,7 +84,7 @@ def test_ivf_flat_ingestion_u8(tmp_path):
8584
use_nuv_implementation=True,
8685
)
8786
)
88-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
87+
assert accuracy(result, gt_i) > 0.98
8988

9089

9190
def test_ivf_flat_ingestion_f32(tmp_path):
@@ -112,13 +111,13 @@ def test_ivf_flat_ingestion_f32(tmp_path):
112111
result = np.transpose(
113112
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
114113
)
115-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
114+
assert accuracy(result, gt_i) > 0.98
116115

117116
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
118117
result = np.transpose(
119118
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
120119
)
121-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
120+
assert accuracy(result, gt_i) > 0.98
122121
result = np.transpose(
123122
index_ram.query(
124123
np.transpose(query_vectors),
@@ -127,7 +126,7 @@ def test_ivf_flat_ingestion_f32(tmp_path):
127126
use_nuv_implementation=True,
128127
)
129128
)
130-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
129+
assert accuracy(result, gt_i) > 0.98
131130

132131

133132
def test_ivf_flat_ingestion_fvec(tmp_path):
@@ -157,13 +156,13 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
157156
result = np.transpose(
158157
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
159158
)
160-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
159+
assert accuracy(result, gt_i) > 0.98
161160

162-
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
161+
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype)
163162
result = np.transpose(
164163
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
165164
)
166-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
165+
assert accuracy(result, gt_i) > 0.98
167166
result = np.transpose(
168167
index_ram.query(
169168
np.transpose(query_vectors),
@@ -172,4 +171,4 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
172171
use_nuv_implementation=True,
173172
)
174173
)
175-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
174+
assert accuracy(result, gt_i) > 0.98

0 commit comments

Comments
 (0)