@@ -40,22 +40,43 @@ def __init__(
4040
4141 def query (self , queries : np .ndarray , k , ** kwargs ):
4242 updated_ids = set (self .read_updated_ids ())
43- internal_results_d , internal_results_i = self .query_internal (queries , k , ** kwargs )
43+ retrieval_k = k
44+ if len (updated_ids ) > 0 :
45+ retrieval_k = 2 * k
46+ internal_results_d , internal_results_i = self .query_internal (queries , retrieval_k , ** kwargs )
4447 if self .update_arrays_uri is None :
45- return internal_results_d , internal_results_i
48+ return internal_results_d [:, 0 : k ], internal_results_i [:, 0 : k ]
4649
47- addition_results_d , addition_results_i = self .query_additions (queries , k )
4850 # Filter updated vectors
4951 query_id = 0
5052 for query in internal_results_i :
5153 res_id = 0
5254 for res in query :
5355 if res in updated_ids :
5456 internal_results_d [query_id , res_id ] = MAX_FLOAT_32
55- internal_results_i [query_id , res_id ] = 0
57+ internal_results_i [query_id , res_id ] = MAX_UINT64
5658 res_id += 1
5759 query_id += 1
60+ sort_index = np .argsort (internal_results_d , axis = 1 )
61+ internal_results_d = np .take_along_axis (internal_results_d , sort_index , axis = 1 )
62+ internal_results_i = np .take_along_axis (internal_results_i , sort_index , axis = 1 )
63+
5864 # Merge update results
65+ addition_results_d , addition_results_i = self .query_additions (queries , k )
66+ if addition_results_d is None :
67+ return internal_results_d [:, 0 :k ], internal_results_i [:, 0 :k ]
68+
69+ query_id = 0
70+ for query in addition_results_d :
71+ res_id = 0
72+ for res in query :
73+ if addition_results_d [query_id , res_id ] == 0 and addition_results_i [query_id , res_id ] == 0 :
74+ addition_results_d [query_id , res_id ] = MAX_FLOAT_32
75+ addition_results_i [query_id , res_id ] = MAX_UINT64
76+ res_id += 1
77+ query_id += 1
78+
79+
5980 results_d = np .hstack ((internal_results_d , addition_results_d ))
6081 results_i = np .hstack ((internal_results_i , addition_results_i ))
6182 sort_index = np .argsort (results_d , axis = 1 )
@@ -69,6 +90,8 @@ def query_internal(self, queries: np.ndarray, k, **kwargs):
6990 def query_additions (self , queries : np .ndarray , k ):
7091 assert queries .dtype == np .float32
7192 additions_vectors , additions_external_ids = self .read_additions ()
93+ if additions_vectors is None :
94+ return None , None
7295 queries_m = array_to_matrix (np .transpose (queries ))
7396 d , i = query_vq_heap_pyarray (
7497 array_to_matrix (np .transpose (additions_vectors ).astype (self .dtype )),
@@ -82,16 +105,19 @@ def update(self, vector: np.array, external_id: np.uint64):
82105 updates_array = self .open_updates_array ()
83106 updates_array [external_id ] = vector
84107 updates_array .close ()
108+ self .consolidate_update_fragments ()
85109
86110 def update_batch (self , vectors : np .ndarray , external_ids : np .array ):
87111 updates_array = self .open_updates_array ()
88112 updates_array [external_ids ] = {'vector' : vectors }
89113 updates_array .close ()
114+ self .consolidate_update_fragments ()
90115
91116 def delete (self , external_id : np .uint64 ):
92117 updates_array = self .open_updates_array ()
93118 updates_array [external_id ] = np .array ([], dtype = self .dtype )
94119 updates_array .close ()
120+ self .consolidate_update_fragments ()
95121
96122 def delete_batch (self , external_ids : np .array ):
97123 updates_array = self .open_updates_array ()
@@ -100,6 +126,13 @@ def delete_batch(self, external_ids: np.array):
100126 deletes [i ] = np .array ([], dtype = self .dtype )
101127 updates_array [external_ids ] = {'vector' : deletes }
102128 updates_array .close ()
129+ self .consolidate_update_fragments ()
130+
131+ def consolidate_update_fragments (self ):
132+ fragments_info = tiledb .array_fragments (self .update_arrays_uri )
133+ if (len (fragments_info ) > 10 ):
134+ tiledb .consolidate (self .update_arrays_uri )
135+ tiledb .vacuum (self .update_arrays_uri )
103136
    def get_updates_uri(self):
        """Return the URI of the TileDB updates array (None if the index has no updates array)."""
        return self.update_arrays_uri
@@ -111,8 +144,10 @@ def read_additions(self) -> (np.ndarray, np.array):
111144 q = updates_array .query (attrs = ('vector' ,), coords = True )
112145 data = q [:]
113146 additions_filter = [len (item ) > 0 for item in data ["vector" ]]
114- return np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
115-
147+ if len (data ["external_id" ][additions_filter ]) > 0 :
148+ return np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
149+ else :
150+ return None , None
116151 def read_updated_ids (self ) -> np .array :
117152 if self .update_arrays_uri is None :
118153 return np .array ([], np .uint64 )
0 commit comments