Skip to content

Commit a4436be

Browse files
Merge pull request #80 from TileDB-Inc/npapa/fix_query
Fix segfault in finite ram queries
2 parents b679c08 + f4f9656 commit a4436be

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

apis/python/test/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,11 @@ def create_array(path: str, data):
181181
tiledb.Array.create(path, schema)
182182
with tiledb.open(path, "w") as A:
183183
A[:] = data
184+
185+
def accuracy(result, gt):
186+
found = 0
187+
total = 0
188+
for i in range(len(result)):
189+
total+=len(result[i])
190+
found+=len(np.intersect1d(result[i], gt[i]))
191+
return found/total

apis/python/test/test_ingestion.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_flat_ingestion_u8(tmp_path):
2222
source_type=source_type,
2323
)
2424
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
25-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
25+
assert accuracy(result, gt_i) > 0.98
2626

2727

2828
def test_flat_ingestion_f32(tmp_path):
@@ -43,7 +43,7 @@ def test_flat_ingestion_f32(tmp_path):
4343
source_type=source_type,
4444
)
4545
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
46-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
46+
assert accuracy(result, gt_i) > 0.98
4747

4848

4949
def test_ivf_flat_ingestion_u8(tmp_path):
@@ -52,13 +52,12 @@ def test_ivf_flat_ingestion_u8(tmp_path):
5252
k = 10
5353
size = 100000
5454
partitions = 100
55-
create_random_dataset_u8(nb=size, d=100, nq=2, k=k, path=dataset_dir)
55+
create_random_dataset_u8(nb=size, d=100, nq=10, k=k, path=dataset_dir)
5656
source_type = "U8BIN"
5757
dtype = np.uint8
5858

5959
query_vectors = get_queries(dataset_dir, dtype=dtype)
6060
gt_i, gt_d = get_groundtruth(dataset_dir, k)
61-
6261
index = ingest(
6362
index_type="IVF_FLAT",
6463
array_uri=array_uri,
@@ -68,15 +67,15 @@ def test_ivf_flat_ingestion_u8(tmp_path):
6867
input_vectors_per_work_item=int(size / 10),
6968
)
7069
result = np.transpose(
71-
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
70+
index.query(np.transpose(query_vectors), k=k, nprobe=10)
7271
)
73-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
72+
assert accuracy(result, gt_i) > 0.98
7473

75-
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype)
74+
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
7675
result = np.transpose(
7776
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
7877
)
79-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
78+
assert accuracy(result, gt_i) > 0.98
8079
result = np.transpose(
8180
index_ram.query(
8281
np.transpose(query_vectors),
@@ -85,7 +84,7 @@ def test_ivf_flat_ingestion_u8(tmp_path):
8584
use_nuv_implementation=True,
8685
)
8786
)
88-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
87+
assert accuracy(result, gt_i) > 0.98
8988

9089

9190
def test_ivf_flat_ingestion_f32(tmp_path):
@@ -112,13 +111,13 @@ def test_ivf_flat_ingestion_f32(tmp_path):
112111
result = np.transpose(
113112
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
114113
)
115-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
114+
assert accuracy(result, gt_i) > 0.98
116115

117-
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype)
116+
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
118117
result = np.transpose(
119118
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
120119
)
121-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
120+
assert accuracy(result, gt_i) > 0.98
122121
result = np.transpose(
123122
index_ram.query(
124123
np.transpose(query_vectors),
@@ -127,7 +126,7 @@ def test_ivf_flat_ingestion_f32(tmp_path):
127126
use_nuv_implementation=True,
128127
)
129128
)
130-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
129+
assert accuracy(result, gt_i) > 0.98
131130

132131

133132
def test_ivf_flat_ingestion_fvec(tmp_path):
@@ -157,13 +156,13 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
157156
result = np.transpose(
158157
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
159158
)
160-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
159+
assert accuracy(result, gt_i) > 0.98
161160

162161
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype)
163162
result = np.transpose(
164163
index_ram.query(np.transpose(query_vectors), k=k, nprobe=partitions)
165164
)
166-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
165+
assert accuracy(result, gt_i) > 0.98
167166
result = np.transpose(
168167
index_ram.query(
169168
np.transpose(query_vectors),
@@ -172,4 +171,4 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
172171
use_nuv_implementation=True,
173172
)
174173
)
175-
assert np.array_equal(np.sort(result, axis=1), np.sort(gt_i, axis=1))
174+
assert accuracy(result, gt_i) > 0.98

src/include/detail/ivf/qv.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -609,16 +609,16 @@ auto nuv_query_heap_finite_ram(
609609
_i.start();
610610

611611
size_t parts_per_thread =
612-
(size(active_partitions) + nthreads - 1) / nthreads;
612+
(shuffled_db.num_col_parts() + nthreads - 1) / nthreads;
613613

614614
std::vector<std::future<void>> futs;
615615
futs.reserve(nthreads);
616616

617617
for (size_t n = 0; n < nthreads; ++n) {
618618
auto first_part =
619-
std::min<size_t>(n * parts_per_thread, size(active_partitions));
619+
std::min<size_t>(n * parts_per_thread, shuffled_db.num_col_parts());
620620
auto last_part =
621-
std::min<size_t>((n + 1) * parts_per_thread, size(active_partitions));
621+
std::min<size_t>((n + 1) * parts_per_thread, shuffled_db.num_col_parts());
622622

623623
if (first_part != last_part) {
624624
futs.emplace_back(std::async(
@@ -794,16 +794,16 @@ auto qv_query_heap_finite_ram(
794794

795795
// size_t block_size = (size(active_partitions) + nthreads - 1) / nthreads;
796796
size_t parts_per_thread =
797-
(size(active_partitions) + nthreads - 1) / nthreads;
797+
(shuffled_db.num_col_parts() + nthreads - 1) / nthreads;
798798

799799
std::vector<std::future<void>> futs;
800800
futs.reserve(nthreads);
801801

802802
for (size_t n = 0; n < nthreads; ++n) {
803803
auto first_part =
804-
std::min<size_t>(n * parts_per_thread, size(active_partitions));
804+
std::min<size_t>(n * parts_per_thread, shuffled_db.num_col_parts());
805805
auto last_part =
806-
std::min<size_t>((n + 1) * parts_per_thread, size(active_partitions));
806+
std::min<size_t>((n + 1) * parts_per_thread, shuffled_db.num_col_parts());
807807

808808
if (first_part != last_part) {
809809
futs.emplace_back(

0 commit comments

Comments
 (0)