Skip to content

Commit a39410b

Browse files
author
Nikos Papailiou
committed
Parallelize queries for arrays with updates
1 parent 70f44e7 commit a39410b

File tree

2 files changed

+45
-24
lines changed

2 files changed

+45
-24
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import multiprocessing as mp
12
import numpy as np
23
import sys
34

@@ -37,15 +38,26 @@ def __init__(
3738
self.update_arrays_uri = None
3839
self.index_version = self.group.meta.get("index_version", "")
3940

40-
4141
def query(self, queries: np.ndarray, k, **kwargs):
42+
if self.update_arrays_uri is None:
43+
return self.query_internal(queries, k, **kwargs)
44+
45+
# Query with updates
4246
updated_ids = set(self.read_updated_ids())
4347
retrieval_k = k
4448
if len(updated_ids) > 0:
4549
retrieval_k = 2*k
50+
51+
# Perform the queries in parallel
52+
kwargs["nthreads"] = int(mp.cpu_count()/2)
53+
parent_conn, child_conn = mp.Pipe()
54+
p = mp.Process(
55+
target=Index.query_additions,
56+
args=(child_conn, queries, k, self.dtype, self.update_arrays_uri, int(mp.cpu_count()/2)))
57+
p.start()
4658
internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
47-
if self.update_arrays_uri is None:
48-
return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
59+
addition_results_d, addition_results_i = parent_conn.recv()
60+
p.join()
4961

5062
# Filter updated vectors
5163
query_id = 0
@@ -62,7 +74,6 @@ def query(self, queries: np.ndarray, k, **kwargs):
6274
internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
6375

6476
# Merge update results
65-
addition_results_d, addition_results_i = self.query_additions(queries, k)
6677
if addition_results_d is None:
6778
return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
6879

@@ -84,22 +95,38 @@ def query(self, queries: np.ndarray, k, **kwargs):
8495
results_i = np.take_along_axis(results_i, sort_index, axis=1)
8596
return results_d[:, 0:k], results_i[:, 0:k]
8697

87-
def query_internal(self, queries: np.ndarray, k, **kwargs):
88-
raise NotImplementedError
89-
90-
def query_additions(self, queries: np.ndarray, k):
98+
@staticmethod
def query_additions(conn, queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8):
    """Query the additions (updates) array; runs inside a worker process.

    The parent spawns this via multiprocessing.Process and then blocks on
    parent_conn.recv(), so this function must ALWAYS send exactly one
    (distances, ids) tuple over `conn` before returning.

    Args:
        conn: multiprocessing Connection used to ship results to the parent.
        queries: float32 query matrix, one query per row.
        k: number of nearest neighbors to retrieve per query.
        dtype: vector dtype of the index; addition vectors are cast to it.
        update_arrays_uri: URI of the TileDB updates array (may be None or
            contain no additions).
        nthreads: parallelism passed to the native query kernel.
    """
    assert queries.dtype == np.float32
    additions_vectors, additions_external_ids = Index.read_additions(update_arrays_uri)
    if additions_vectors is None:
        # Bug fix: the original returned here WITHOUT sending anything,
        # leaving the parent blocked forever on parent_conn.recv() whenever
        # the updates array held no additions (e.g. only deletions).
        conn.send((None, None))
        conn.close()
        return
    queries_m = array_to_matrix(np.transpose(queries))
    d, i = query_vq_heap_pyarray(
        array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
        queries_m,
        StdVector_u64(additions_external_ids),
        k,
        nthreads)
    conn.send((np.transpose(np.array(d)), np.transpose(np.array(i))))
    conn.close()
113+
114+
@staticmethod
def read_additions(update_arrays_uri):
    """Read all added vectors and their external ids from the updates array.

    Args:
        update_arrays_uri: URI of the TileDB updates array, or None.

    Returns:
        A (vectors, external_ids) pair of numpy arrays, or (None, None)
        when there is no updates array or it contains no additions.
        Entries whose 'vector' attribute is empty are deletions and are
        filtered out.
    """
    if update_arrays_uri is None:
        return None, None
    # Context manager guarantees the array handle is closed even if the
    # read raises; the original closed it manually on the success path only.
    with tiledb.open(update_arrays_uri, mode="r") as updates_array:
        q = updates_array.query(attrs=('vector',), coords=True)
        data = q[:]
    additions_filter = [len(item) > 0 for item in data["vector"]]
    if len(data["external_id"][additions_filter]) > 0:
        return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
    return None, None
127+
128+
def query_internal(self, queries: np.ndarray, k, **kwargs):
    """Backend-specific query hook; concrete index subclasses must override.

    Args:
        queries: query matrix, one query per row.
        k: number of nearest neighbors to retrieve per query.
        **kwargs: backend-specific query options.

    Raises:
        NotImplementedError: always, on this abstract base.
    """
    raise NotImplementedError
103130

104131
def update(self, vector: np.array, external_id: np.uint64):
105132
updates_array = self.open_updates_array()
@@ -141,23 +168,13 @@ def consolidate_update_fragments(self):
141168
def get_updates_uri(self):
142169
return self.update_arrays_uri
143170

144-
def read_additions(self) -> (np.ndarray, np.array):
145-
if self.update_arrays_uri is None:
146-
return None, None
147-
updates_array = tiledb.open(self.update_arrays_uri, mode="r")
148-
q = updates_array.query(attrs=('vector',), coords=True)
149-
data = q[:]
150-
additions_filter = [len(item) > 0 for item in data["vector"]]
151-
if len(data["external_id"][additions_filter]) > 0:
152-
return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
153-
else:
154-
return None, None
155171
def read_updated_ids(self) -> np.ndarray:
    """Return the external ids of all updated (added or deleted) vectors.

    Returns:
        Array of uint64 external ids; empty when no updates array exists.
    """
    if self.update_arrays_uri is None:
        return np.array([], np.uint64)
    # Context manager guarantees the array handle is released even if the
    # read raises; the original closed it manually on the success path only.
    with tiledb.open(self.update_arrays_uri, mode="r") as updates_array:
        q = updates_array.query(attrs=('vector',), coords=True)
        data = q[:]
    return data["external_id"]
162179

163180
def open_updates_array(self):
@@ -197,4 +214,8 @@ def consolidate_updates(self):
197214
updates_uri=self.update_arrays_uri
198215
)
199216
tiledb.Array.delete_array(self.update_arrays_uri)
217+
self.group.close()
218+
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
219+
self.group.remove(self.update_arrays_uri)
220+
self.group.close()
200221
return new_index

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1728,7 +1728,7 @@ def consolidate_and_vacuum(
17281728
index_group_uri: str,
17291729
config: Optional[Mapping[str, Any]] = None,
17301730
):
1731-
group = tiledb.Group(index_group_uri, config=config)
1731+
group = tiledb.Group(index_group_uri)
17321732
if INPUT_VECTORS_ARRAY_NAME in group:
17331733
tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
17341734
if EXTERNAL_IDS_ARRAY_NAME in group:

0 commit comments

Comments (0)