1- import multiprocessing as mp
1+ import concurrent .futures as futures
2+ import os
23import numpy as np
34import sys
45
@@ -38,6 +39,7 @@ def __init__(
3839 self .storage_version = self .group .meta .get ("storage_version" , "0.1" )
3940 self .update_arrays_uri = None
4041 self .index_version = self .group .meta .get ("index_version" , "" )
42+ self .thread_executor = futures .ThreadPoolExecutor ()
4143
4244 def query (self , queries : np .ndarray , k , ** kwargs ):
4345 if self .update_arrays_uri is None :
@@ -46,25 +48,19 @@ def query(self, queries: np.ndarray, k, **kwargs):
4648 # Query with updates
4749 # Perform the queries in parallel
4850 retrieval_k = 2 * k
49- kwargs ["nthreads" ] = int (mp .cpu_count () / 2 )
50- parent_conn , child_conn = mp .Pipe ()
51- p = mp .Process (
52- target = Index .query_additions ,
53- args = (
54- child_conn ,
55- queries ,
56- k ,
57- self .dtype ,
58- self .update_arrays_uri ,
59- int (mp .cpu_count () / 2 ),
60- ),
51+ kwargs ["nthreads" ] = int (os .cpu_count () / 2 )
52+ future = self .thread_executor .submit (
53+ Index .query_additions ,
54+ queries ,
55+ k ,
56+ self .dtype ,
57+ self .update_arrays_uri ,
58+ int (os .cpu_count () / 2 ),
6159 )
62- p .start ()
6360 internal_results_d , internal_results_i = self .query_internal (
6461 queries , retrieval_k , ** kwargs
6562 )
66- addition_results_d , addition_results_i , updated_ids = parent_conn .recv ()
67- p .join ()
63+ addition_results_d , addition_results_i , updated_ids = future .result ()
6864
6965 # Filter updated vectors
7066 query_id = 0
@@ -106,16 +102,14 @@ def query(self, queries: np.ndarray, k, **kwargs):
106102
107103 @staticmethod
108104 def query_additions (
109- conn , queries : np .ndarray , k , dtype , update_arrays_uri , nthreads = 8
105+ queries : np .ndarray , k , dtype , update_arrays_uri , nthreads = 8
110106 ):
111107 assert queries .dtype == np .float32
112108 additions_vectors , additions_external_ids , updated_ids = Index .read_additions (
113109 update_arrays_uri
114110 )
115111 if additions_vectors is None :
116- conn .send (None , None , updated_ids )
117- conn .close ()
118- return
112+ return None , None , updated_ids
119113
120114 queries_m = array_to_matrix (np .transpose (queries ))
121115 d , i = query_vq_heap_pyarray (
@@ -125,8 +119,7 @@ def query_additions(
125119 k ,
126120 nthreads ,
127121 )
128- conn .send ((np .transpose (np .array (d )), np .transpose (np .array (i )), updated_ids ))
129- conn .close ()
122+ return np .transpose (np .array (d )), np .transpose (np .array (i )), updated_ids
130123
131124 @staticmethod
132125 def read_additions (update_arrays_uri ) -> (np .ndarray , np .array ):
0 commit comments