Skip to content

Commit 01f9448

Browse files
author
Nikos Papailiou
committed
Use threads instead of processes
1 parent 6b118ef commit 01f9448

File tree

1 file changed

+15
-22
lines changed

1 file changed

+15
-22
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 15 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
import multiprocessing as mp
1+
import concurrent.futures as futures
2+
import os
23
import numpy as np
34
import sys
45

@@ -38,6 +39,7 @@ def __init__(
3839
self.storage_version = self.group.meta.get("storage_version", "0.1")
3940
self.update_arrays_uri = None
4041
self.index_version = self.group.meta.get("index_version", "")
42+
self.thread_executor = futures.ThreadPoolExecutor()
4143

4244
def query(self, queries: np.ndarray, k, **kwargs):
4345
if self.update_arrays_uri is None:
@@ -46,25 +48,19 @@ def query(self, queries: np.ndarray, k, **kwargs):
4648
# Query with updates
4749
# Perform the queries in parallel
4850
retrieval_k = 2 * k
49-
kwargs["nthreads"] = int(mp.cpu_count() / 2)
50-
parent_conn, child_conn = mp.Pipe()
51-
p = mp.Process(
52-
target=Index.query_additions,
53-
args=(
54-
child_conn,
55-
queries,
56-
k,
57-
self.dtype,
58-
self.update_arrays_uri,
59-
int(mp.cpu_count() / 2),
60-
),
51+
kwargs["nthreads"] = int(os.cpu_count() / 2)
52+
future = self.thread_executor.submit(
53+
Index.query_additions,
54+
queries,
55+
k,
56+
self.dtype,
57+
self.update_arrays_uri,
58+
int(os.cpu_count() / 2),
6159
)
62-
p.start()
6360
internal_results_d, internal_results_i = self.query_internal(
6461
queries, retrieval_k, **kwargs
6562
)
66-
addition_results_d, addition_results_i, updated_ids = parent_conn.recv()
67-
p.join()
63+
addition_results_d, addition_results_i, updated_ids = future.result()
6864

6965
# Filter updated vectors
7066
query_id = 0
@@ -106,16 +102,14 @@ def query(self, queries: np.ndarray, k, **kwargs):
106102

107103
@staticmethod
108104
def query_additions(
109-
conn, queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
105+
queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
110106
):
111107
assert queries.dtype == np.float32
112108
additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
113109
update_arrays_uri
114110
)
115111
if additions_vectors is None:
116-
conn.send(None, None, updated_ids)
117-
conn.close()
118-
return
112+
return None, None, updated_ids
119113

120114
queries_m = array_to_matrix(np.transpose(queries))
121115
d, i = query_vq_heap_pyarray(
@@ -125,8 +119,7 @@ def query_additions(
125119
k,
126120
nthreads,
127121
)
128-
conn.send((np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids))
129-
conn.close()
122+
return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
130123

131124
@staticmethod
132125
def read_additions(update_arrays_uri) -> (np.ndarray, np.array):

0 commit comments

Comments
 (0)