@@ -40,22 +40,43 @@ def __init__(
4040
4141    def  query (self , queries : np .ndarray , k , ** kwargs ):
4242        updated_ids  =  set (self .read_updated_ids ())
43-         internal_results_d , internal_results_i  =  self .query_internal (queries , k , ** kwargs )
43+         retrieval_k  =  k 
44+         if  len (updated_ids ) >  0 :
45+             retrieval_k  =  2 * k 
46+         internal_results_d , internal_results_i  =  self .query_internal (queries , retrieval_k , ** kwargs )
4447        if  self .update_arrays_uri  is  None :
45-             return  internal_results_d ,  internal_results_i 
48+             return  internal_results_d [:,  0 : k ],  internal_results_i [:,  0 : k ] 
4649
47-         addition_results_d , addition_results_i  =  self .query_additions (queries , k )
4850        # Filter updated vectors 
4951        query_id  =  0 
5052        for  query  in  internal_results_i :
5153            res_id  =  0 
5254            for  res  in  query :
5355                if  res  in  updated_ids :
5456                    internal_results_d [query_id , res_id ] =  MAX_FLOAT_32 
55-                     internal_results_i [query_id , res_id ] =  0 
57+                     internal_results_i [query_id , res_id ] =  MAX_UINT64 
5658                res_id  +=  1 
5759            query_id  +=  1 
60+         sort_index  =  np .argsort (internal_results_d , axis = 1 )
61+         internal_results_d  =  np .take_along_axis (internal_results_d , sort_index , axis = 1 )
62+         internal_results_i  =  np .take_along_axis (internal_results_i , sort_index , axis = 1 )
63+ 
5864        # Merge update results 
65+         addition_results_d , addition_results_i  =  self .query_additions (queries , k )
66+         if  addition_results_d  is  None :
67+             return  internal_results_d [:, 0 :k ], internal_results_i [:, 0 :k ]
68+ 
69+         query_id  =  0 
70+         for  query  in  addition_results_d :
71+             res_id  =  0 
72+             for  res  in  query :
73+                 if  addition_results_d [query_id , res_id ] ==  0  and  addition_results_i [query_id , res_id ] ==  0 :
74+                     addition_results_d [query_id , res_id ] =  MAX_FLOAT_32 
75+                     addition_results_i [query_id , res_id ] =  MAX_UINT64 
76+                 res_id  +=  1 
77+             query_id  +=  1 
78+ 
79+ 
5980        results_d  =  np .hstack ((internal_results_d , addition_results_d ))
6081        results_i  =  np .hstack ((internal_results_i , addition_results_i ))
6182        sort_index  =  np .argsort (results_d , axis = 1 )
@@ -69,6 +90,8 @@ def query_internal(self, queries: np.ndarray, k, **kwargs):
6990    def  query_additions (self , queries : np .ndarray , k ):
7091        assert  queries .dtype  ==  np .float32 
7192        additions_vectors , additions_external_ids  =  self .read_additions ()
93+         if  additions_vectors  is  None :
94+             return  None , None 
7295        queries_m  =  array_to_matrix (np .transpose (queries ))
7396        d , i  =  query_vq_heap_pyarray (
7497            array_to_matrix (np .transpose (additions_vectors ).astype (self .dtype )),
@@ -80,18 +103,25 @@ def query_additions(self, queries: np.ndarray, k):
80103
def update(self, vector: np.ndarray, external_id: np.uint64):
    """Stage an update for a single vector in the updates array.

    Args:
        vector: New vector value to store for ``external_id``.
        external_id: Id of the vector being updated.

    Note: annotation fixed from ``np.array`` (the constructor function)
    to ``np.ndarray`` (the actual array type).
    """
    updates_array = self.open_updates_array()
    # Wrap the single vector in a 1-element object ndarray so the write
    # matches the {'vector': <object array>} shape used by delete/delete_batch
    # (presumably the var-length attribute layout the array expects).
    vectors = np.empty(1, dtype='O')
    vectors[0] = vector
    updates_array[external_id] = {'vector': vectors}
    updates_array.close()
    # Each write lands in a new fragment; consolidate once too many pile up.
    self.consolidate_update_fragments()
85111
def update_batch(self, vectors: np.ndarray, external_ids: np.array):
    """Stage updates for a batch of vectors keyed by their external ids."""
    arr = self.open_updates_array()
    arr[external_ids] = {'vector': vectors}
    arr.close()
    self.consolidate_update_fragments()
90117
def delete(self, external_id: np.uint64):
    """Stage a deletion of ``external_id`` by writing an empty vector."""
    arr = self.open_updates_array()
    # An empty vector of the index dtype acts as the deletion marker,
    # wrapped in a 1-element object array like the batch variant.
    tombstone = np.empty(1, dtype='O')
    tombstone[0] = np.array([], dtype=self.dtype)
    arr[external_id] = {'vector': tombstone}
    arr.close()
    self.consolidate_update_fragments()
95125
96126    def  delete_batch (self , external_ids : np .array ):
97127        updates_array  =  self .open_updates_array ()
@@ -100,6 +130,13 @@ def delete_batch(self, external_ids: np.array):
100130            deletes [i ] =  np .array ([], dtype = self .dtype )
101131        updates_array [external_ids ] =  {'vector' : deletes }
102132        updates_array .close ()
133+         self .consolidate_update_fragments ()
134+ 
def consolidate_update_fragments(self, max_fragment_count: int = 10):
    """Consolidate and vacuum the updates array when fragments accumulate.

    Every staged update/delete writes a new TileDB fragment; once more than
    ``max_fragment_count`` exist, consolidate them and vacuum the leftovers
    to keep read amplification bounded.

    Args:
        max_fragment_count: Fragment-count threshold that triggers
            consolidation. Defaults to 10, the previously hard-coded value,
            so existing callers are unaffected.
    """
    fragments_info = tiledb.array_fragments(self.update_arrays_uri)
    # Idiomatic condition (was: `if(len(...) > 10):`).
    if len(fragments_info) > max_fragment_count:
        tiledb.consolidate(self.update_arrays_uri)
        tiledb.vacuum(self.update_arrays_uri)
103140
def get_updates_uri(self):
    """Return the URI of the updates array (may be ``None`` when no
    updates array has been created yet)."""
    return self.update_arrays_uri
@@ -111,8 +148,10 @@ def read_additions(self) -> (np.ndarray, np.array):
111148        q  =  updates_array .query (attrs = ('vector' ,), coords = True )
112149        data  =  q [:]
113150        additions_filter  =  [len (item ) >  0  for  item  in  data ["vector" ]]
114-         return  np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
115- 
151+         if  len (data ["external_id" ][additions_filter ]) >  0 :
152+             return  np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
153+         else :
154+             return  None , None 
116155    def  read_updated_ids (self ) ->  np .array :
117156        if  self .update_arrays_uri  is  None :
118157            return  np .array ([], np .uint64 )
0 commit comments