
Commit 19e5803

Merge pull request #141 from TileDB-Inc/npapa/parallel-update-queries
Parallelize update queries
2 parents 70f44e7 + 01f9448 commit 19e5803

File tree: 2 files changed, +82 −50 lines

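Before this change, query() ran the brute-force scan of pending updates serially, after the main index query. The diff below instead submits Index.query_additions to a ThreadPoolExecutor so it overlaps with query_internal, splits os.cpu_count() between the two, over-fetches 2*k candidates from the index, and merges the two result sets. What follows is a minimal, self-contained sketch of that pattern, not the library code; the stand-in functions (additions_topk, index_topk) and the random data are illustrative only.

import concurrent.futures as futures

import numpy as np


def additions_topk(queries, vectors, k):
    # Stand-in for Index.query_additions: exact top-k against the small
    # set of pending update vectors.
    dists = np.linalg.norm(queries[:, None, :] - vectors[None, :, :], axis=2)
    idx = np.argsort(dists, axis=1)[:, :k]
    return np.take_along_axis(dists, idx, axis=1), idx


def index_topk(queries, k):
    # Stand-in for Index.query_internal (the ANN index search).
    rng = np.random.default_rng(0)
    vectors = rng.random((1024, queries.shape[1]), dtype=np.float32)
    return additions_topk(queries, vectors, k)


k = 3
rng = np.random.default_rng(1)
queries = rng.random((4, 8), dtype=np.float32)
updates = rng.random((16, 8), dtype=np.float32)

executor = futures.ThreadPoolExecutor()
# Kick off the additions scan in the background ...
future = executor.submit(additions_topk, queries, updates, k)
# ... while the main index query runs on this thread, over-fetching 2*k
# candidates so results superseded by updates can be dropped later.
main_d, main_i = index_topk(queries, 2 * k)
add_d, add_i = future.result()

# Merge both candidate lists and keep the k closest per query.
all_d = np.hstack((main_d, add_d))
order = np.argsort(all_d, axis=1)[:, :k]
print(np.take_along_axis(all_d, order, axis=1))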
Lines changed: 81 additions & 49 deletions
@@ -1,3 +1,5 @@
+import concurrent.futures as futures
+import os
 import numpy as np
 import sys
 
@@ -20,6 +22,7 @@ class Index:
     config: Optional[Mapping[str, Any]]
        config dictionary, defaults to None
    """
+
    def __init__(
        self,
        uri: str,
@@ -36,16 +39,28 @@ def __init__(
        self.storage_version = self.group.meta.get("storage_version", "0.1")
        self.update_arrays_uri = None
        self.index_version = self.group.meta.get("index_version", "")
-
+        self.thread_executor = futures.ThreadPoolExecutor()
 
    def query(self, queries: np.ndarray, k, **kwargs):
-        updated_ids = set(self.read_updated_ids())
-        retrieval_k = k
-        if len(updated_ids) > 0:
-            retrieval_k = 2*k
-        internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
        if self.update_arrays_uri is None:
-            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
+            return self.query_internal(queries, k, **kwargs)
+
+        # Query with updates
+        # Perform the queries in parallel
+        retrieval_k = 2 * k
+        kwargs["nthreads"] = int(os.cpu_count() / 2)
+        future = self.thread_executor.submit(
+            Index.query_additions,
+            queries,
+            k,
+            self.dtype,
+            self.update_arrays_uri,
+            int(os.cpu_count() / 2),
+        )
+        internal_results_d, internal_results_i = self.query_internal(
+            queries, retrieval_k, **kwargs
+        )
+        addition_results_d, addition_results_i, updated_ids = future.result()
 
        # Filter updated vectors
        query_id = 0
@@ -62,112 +77,124 @@ def query(self, queries: np.ndarray, k, **kwargs):
        internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
 
        # Merge update results
-        addition_results_d, addition_results_i = self.query_additions(queries, k)
        if addition_results_d is None:
            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
 
        query_id = 0
        for query in addition_results_d:
            res_id = 0
            for res in query:
-                if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
+                if (
+                    addition_results_d[query_id, res_id] == 0
+                    and addition_results_i[query_id, res_id] == 0
+                ):
                    addition_results_d[query_id, res_id] = MAX_FLOAT_32
                    addition_results_i[query_id, res_id] = MAX_UINT64
                res_id += 1
            query_id += 1
 
-
        results_d = np.hstack((internal_results_d, addition_results_d))
        results_i = np.hstack((internal_results_i, addition_results_i))
        sort_index = np.argsort(results_d, axis=1)
        results_d = np.take_along_axis(results_d, sort_index, axis=1)
        results_i = np.take_along_axis(results_i, sort_index, axis=1)
        return results_d[:, 0:k], results_i[:, 0:k]
 
-    def query_internal(self, queries: np.ndarray, k, **kwargs):
-        raise NotImplementedError
-
-    def query_additions(self, queries: np.ndarray, k):
+    @staticmethod
+    def query_additions(
+        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
+    ):
        assert queries.dtype == np.float32
-        additions_vectors, additions_external_ids = self.read_additions()
+        additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
+            update_arrays_uri
+        )
        if additions_vectors is None:
-            return None, None
+            return None, None, updated_ids
+
        queries_m = array_to_matrix(np.transpose(queries))
        d, i = query_vq_heap_pyarray(
-            array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
+            array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
            queries_m,
            StdVector_u64(additions_external_ids),
            k,
-            8)
-        return np.transpose(np.array(d)), np.transpose(np.array(i))
+            nthreads,
+        )
+        return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
+
+    @staticmethod
+    def read_additions(update_arrays_uri) -> (np.ndarray, np.array):
+        if update_arrays_uri is None:
+            return None, None, np.array([], np.uint64)
+        updates_array = tiledb.open(update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=("vector",), coords=True)
+        data = q[:]
+        updates_array.close()
+        updated_ids = data["external_id"]
+        additions_filter = [len(item) > 0 for item in data["vector"]]
+        if len(data["external_id"][additions_filter]) > 0:
+            return (
+                np.vstack(data["vector"][additions_filter]),
+                data["external_id"][additions_filter],
+                updated_ids
+            )
+        else:
+            return None, None, updated_ids
+
+    def query_internal(self, queries: np.ndarray, k, **kwargs):
+        raise NotImplementedError
 
    def update(self, vector: np.array, external_id: np.uint64):
        updates_array = self.open_updates_array()
-        vectors = np.empty((1), dtype='O')
+        vectors = np.empty((1), dtype="O")
        vectors[0] = vector
-        updates_array[external_id] = {'vector': vectors}
+        updates_array[external_id] = {"vector": vectors}
        updates_array.close()
        self.consolidate_update_fragments()
 
    def update_batch(self, vectors: np.ndarray, external_ids: np.array):
        updates_array = self.open_updates_array()
-        updates_array[external_ids] = {'vector': vectors}
+        updates_array[external_ids] = {"vector": vectors}
        updates_array.close()
        self.consolidate_update_fragments()
 
    def delete(self, external_id: np.uint64):
        updates_array = self.open_updates_array()
-        deletes = np.empty((1), dtype='O')
+        deletes = np.empty((1), dtype="O")
        deletes[0] = np.array([], dtype=self.dtype)
-        updates_array[external_id] = {'vector': deletes}
+        updates_array[external_id] = {"vector": deletes}
        updates_array.close()
        self.consolidate_update_fragments()
 
    def delete_batch(self, external_ids: np.array):
        updates_array = self.open_updates_array()
-        deletes = np.empty((len(external_ids)), dtype='O')
+        deletes = np.empty((len(external_ids)), dtype="O")
        for i in range(len(external_ids)):
            deletes[i] = np.array([], dtype=self.dtype)
-        updates_array[external_ids] = {'vector': deletes}
+        updates_array[external_ids] = {"vector": deletes}
        updates_array.close()
        self.consolidate_update_fragments()
 
    def consolidate_update_fragments(self):
        fragments_info = tiledb.array_fragments(self.update_arrays_uri)
-        if(len(fragments_info) > 10):
+        if len(fragments_info) > 10:
            tiledb.consolidate(self.update_arrays_uri)
            tiledb.vacuum(self.update_arrays_uri)
 
    def get_updates_uri(self):
        return self.update_arrays_uri
 
-    def read_additions(self) -> (np.ndarray, np.array):
-        if self.update_arrays_uri is None:
-            return None, None
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        additions_filter = [len(item) > 0 for item in data["vector"]]
-        if len(data["external_id"][additions_filter]) > 0:
-            return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
-        else:
-            return None, None
-    def read_updated_ids(self) -> np.array:
-        if self.update_arrays_uri is None:
-            return np.array([], np.uint64)
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        return data["external_id"]
-
    def open_updates_array(self):
        if self.update_arrays_uri is None:
-            updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
+            updates_array_name = storage_formats[self.storage_version][
+                "UPDATES_ARRAY_NAME"
+            ]
            updates_array_uri = f"{self.group.uri}/{updates_array_name}"
            if tiledb.array_exists(updates_array_uri):
                raise RuntimeError(f"Array {updates_array_uri} already exists.")
            external_id_dim = tiledb.Dim(
-                name="external_id", domain=(0, MAX_UINT64-1), dtype=np.dtype(np.uint64)
+                name="external_id",
+                domain=(0, MAX_UINT64 - 1),
+                dtype=np.dtype(np.uint64),
            )
            dom = tiledb.Domain(external_id_dim)
            vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
@@ -188,13 +215,18 @@ def open_updates_array(self):
 
    def consolidate_updates(self):
        from tiledb.vector_search.ingestion import ingest
+
        new_index = ingest(
            index_type=self.index_type,
            index_uri=self.uri,
            size=self.size,
            source_uri=self.db_uri,
            external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri
+            updates_uri=self.update_arrays_uri,
        )
        tiledb.Array.delete_array(self.update_arrays_uri)
+        self.group.close()
+        self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
+        self.group.remove(self.update_arrays_uri)
+        self.group.close()
        return new_index
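The merge step in query() treats (distance == 0, id == 0) entries in the additions results as empty slots and rewrites them to MAX_FLOAT_32 / MAX_UINT64, so they sort after every real candidate in the merged argsort. Below is a vectorized sketch of that step, assuming the two module-level constants hold the usual float32 and uint64 maxima:

import numpy as np

MAX_FLOAT_32 = np.finfo(np.float32).max  # assumed value of the module constant
MAX_UINT64 = np.iinfo(np.uint64).max     # assumed value of the module constant

# Two queries with k = 3; the second row carries two (0, 0) empty slots.
add_d = np.array([[0.5, 0.9, 1.2], [0.7, 0.0, 0.0]], dtype=np.float32)
add_i = np.array([[11, 12, 13], [21, 0, 0]], dtype=np.uint64)

# Equivalent of the nested loop in query(): empty slots get sentinel
# values so any later distance sort pushes them to the end.
pad = (add_d == 0) & (add_i == 0)
add_d[pad] = MAX_FLOAT_32
add_i[pad] = MAX_UINT64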

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 1 addition & 1 deletion
@@ -1728,7 +1728,7 @@ def consolidate_and_vacuum(
    index_group_uri: str,
    config: Optional[Mapping[str, Any]] = None,
 ):
-    group = tiledb.Group(index_group_uri, config=config)
+    group = tiledb.Group(index_group_uri)
    if INPUT_VECTORS_ARRAY_NAME in group:
        tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
    if EXTERNAL_IDS_ARRAY_NAME in group:
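The one-line ingestion.py change stops passing config directly to tiledb.Group; note that the new consolidate_updates code above threads the config through tiledb.Ctx instead. A sketch of that pattern, with a placeholder URI and illustrative settings (the group is assumed to already exist):

import tiledb

config = {"vfs.s3.region": "us-east-1"}  # illustrative settings only
group_uri = "path/to/index_group"        # placeholder URI

# Pattern used in consolidate_updates(): wrap the config mapping in a
# Ctx rather than handing it to Group directly.
group = tiledb.Group(group_uri, "r", ctx=tiledb.Ctx(config))
print(group.meta.get("storage_version", "0.1"))
group.close()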

0 commit comments