1- // Copyright © 2024 Pathway
1+ // Copyright © 2025 Pathway
22
33use std:: collections:: HashMap ;
4+ use std:: sync:: Arc ;
45
56use crate :: async_runtime:: create_async_tokio_runtime;
67use crate :: engine:: error:: DynResult ;
@@ -14,7 +15,7 @@ use qdrant_client::qdrant::{
1415use qdrant_client:: Qdrant ;
1516
1617use super :: {
17- DerivedFilteredSearchIndex , ExternalIndex , ExternalIndexFactory , KeyScoreMatch ,
18+ DerivedFilteredSearchIndex , ExternalIndex , ExternalIndexFactory , IndexingError , KeyScoreMatch ,
1819 KeyToU64IdMapper , NonFilteringExternalIndex ,
1920} ;
2021
@@ -28,34 +29,34 @@ pub struct QdrantIndex {
2829
2930impl QdrantIndex {
3031 pub fn new (
31- url : String ,
32+ url : & str ,
3233 collection_name : String ,
3334 vector_size : usize ,
3435 api_key : Option < String > ,
35- ) -> DynResult < QdrantIndex > {
36+ ) -> Result < Self , Error > {
3637 let runtime = create_async_tokio_runtime ( )
3738 . map_err ( |e| Error :: Other ( format ! ( "Failed to create async runtime: {e}" ) . into ( ) ) ) ?;
3839
39- let client = Qdrant :: from_url ( & url)
40+ let client = Qdrant :: from_url ( url)
4041 . api_key ( api_key)
4142 . build ( )
42- . map_err ( |e| Error :: Other ( format ! ( "Failed to create Qdrant client: {e}" ) . into ( ) ) ) ?;
43-
44- let collection_exists = runtime
45- . block_on ( client. collection_exists ( & collection_name) )
46- . map_err ( |e| {
47- Error :: Other ( format ! ( "Failed to check collection existence: {e}" ) . into ( ) )
48- } ) ? ;
49-
50- if !collection_exists {
51- runtime
52- . block_on ( client . create_collection (
53- CreateCollectionBuilder :: new ( collection_name . clone ( ) ) . vectors_config (
54- VectorParamsBuilder :: new ( vector_size as u64 , Distance :: Cosine ) ,
55- ) ,
56- ) )
57- . map_err ( |e| Error :: Other ( format ! ( "Failed to create collection: {e}" ) . into ( ) ) ) ? ;
58- }
43+ . map_err ( IndexingError :: from ) ?;
44+
45+ runtime. block_on ( async {
46+ let exists = client. collection_exists ( & collection_name) . await ? ;
47+
48+ if !exists {
49+ client
50+ . create_collection (
51+ CreateCollectionBuilder :: new ( collection_name . clone ( ) ) . vectors_config (
52+ VectorParamsBuilder :: new ( vector_size as u64 , Distance :: Cosine ) ,
53+ ) ,
54+ )
55+ . await ? ;
56+ }
57+
58+ Ok :: < _ , IndexingError > ( ( ) )
59+ } ) ? ;
5960
6061 Ok ( QdrantIndex {
6162 client,
@@ -66,19 +67,23 @@ impl QdrantIndex {
6667 } )
6768 }
6869
69- fn search_one ( & self , data : & [ f64 ] , limit : usize ) -> DynResult < Vec < KeyScoreMatch > > {
70+ #[ allow( clippy:: cast_possible_truncation) ]
71+ async fn search_one_async (
72+ & self ,
73+ data : & [ f64 ] ,
74+ limit : usize ,
75+ ) -> Result < Vec < KeyScoreMatch > , IndexingError > {
7076 let query_vec: Vec < f32 > = data. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
7177 let search_result = self
72- . runtime
73- . block_on (
74- self . client . query (
75- QueryPointsBuilder :: new ( & self . collection_name )
76- . query ( query_vec)
77- . limit ( limit as u64 )
78- . with_payload ( false ) ,
79- ) ,
78+ . client
79+ . query (
80+ QueryPointsBuilder :: new ( & self . collection_name )
81+ . query ( query_vec)
82+ . limit ( limit as u64 )
83+ . with_payload ( false ) ,
8084 )
81- . map_err ( |e| Error :: Other ( format ! ( "Search failed: {e}" ) . into ( ) ) ) ?;
85+ . await ?;
86+
8287 let mut results = Vec :: with_capacity ( search_result. result . len ( ) ) ;
8388 for point in search_result. result {
8489 let Some ( point_id) = point. id else {
@@ -113,69 +118,121 @@ impl QdrantIndex {
113118 Ok ( results)
114119 }
115120
116- fn add_one ( & mut self , key : Key , data : & [ f64 ] ) -> DynResult < ( ) > {
117- if data. len ( ) != self . vector_size {
118- return Err ( format ! (
119- "Vector size mismatch: expected {}, got {}" ,
120- self . vector_size,
121- data. len( )
122- )
123- . into ( ) ) ;
124- }
121+ #[ allow( clippy:: cast_possible_truncation) ]
122+ fn add_batch ( & mut self , data : Vec < ( Key , Vec < f64 > ) > ) -> Result < ( ) , IndexingError > {
123+ let mut points = Vec :: with_capacity ( data. len ( ) ) ;
125124
126- let key_id = self . key_to_id_mapper . get_next_free_u64_id ( key) ;
127- let vec_f32: Vec < f32 > = data. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
125+ for ( key, vec_data) in data {
126+ if vec_data. len ( ) != self . vector_size {
127+ return Err ( IndexingError :: Io ( std:: io:: Error :: new (
128+ std:: io:: ErrorKind :: InvalidData ,
129+ format ! (
130+ "Vector size mismatch: expected {}, got {}" ,
131+ self . vector_size,
132+ vec_data. len( )
133+ ) ,
134+ ) ) ) ;
135+ }
128136
129- self . runtime
130- . block_on ( self . client . upsert_points ( UpsertPointsBuilder :: new (
131- & self . collection_name ,
132- vec ! [ PointStruct :: new(
133- key_id,
134- vec_f32,
135- HashMap :: <String , Value >:: new( ) ,
136- ) ] ,
137- ) ) )
138- . map_err ( |e| Error :: Other ( format ! ( "Failed to add point: {e}" ) . into ( ) ) ) ?;
137+ let key_id = self . key_to_id_mapper . get_next_free_u64_id ( key) ;
138+ let vec_f32: Vec < f32 > = vec_data. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
139+ points. push ( PointStruct :: new (
140+ key_id,
141+ vec_f32,
142+ HashMap :: < String , Value > :: new ( ) ,
143+ ) ) ;
144+ }
145+
146+ self . runtime . block_on (
147+ self . client
148+ . upsert_points ( UpsertPointsBuilder :: new ( & self . collection_name , points) ) ,
149+ ) ?;
139150
140151 Ok ( ( ) )
141152 }
142153
143- fn remove_one ( & mut self , key : Key ) -> DynResult < ( ) > {
144- let key_id = self . key_to_id_mapper . remove_key ( key) ?;
154+ fn remove_batch ( & mut self , keys : Vec < Key > ) -> Result < Vec < u64 > , IndexingError > {
155+ let mut key_ids = Vec :: with_capacity ( keys. len ( ) ) ;
156+ let mut missing_keys = Vec :: new ( ) ;
145157
146- self . runtime
147- . block_on (
148- self . client . delete_points (
149- DeletePointsBuilder :: new ( & self . collection_name ) . points ( [ key_id] ) ,
150- ) ,
151- )
152- . map_err ( |e| Error :: Other ( format ! ( "Failed to remove point: {e}" ) . into ( ) ) ) ?;
158+ for key in keys {
159+ match self . key_to_id_mapper . remove_key ( key) {
160+ Ok ( key_id) => key_ids. push ( key_id) ,
161+ Err ( _) => missing_keys. push ( key) ,
162+ }
163+ }
153164
154- Ok ( ( ) )
165+ if !key_ids. is_empty ( ) {
166+ self . runtime . block_on ( self . client . delete_points (
167+ DeletePointsBuilder :: new ( & self . collection_name ) . points ( key_ids. clone ( ) ) ,
168+ ) ) ?;
169+ }
170+
171+ Ok ( key_ids)
155172 }
156173}
157174
158175impl NonFilteringExternalIndex < Vec < f64 > , Vec < f64 > > for QdrantIndex {
159176 fn add ( & mut self , add_data : Vec < ( Key , Vec < f64 > ) > ) -> Vec < ( Key , DynResult < ( ) > ) > {
160- add_data
161- . into_iter ( )
162- . map ( |( key, data) | ( key, self . add_one ( key, & data) ) )
163- . collect ( )
177+ if add_data. is_empty ( ) {
178+ return Vec :: new ( ) ;
179+ }
180+
181+ let keys: Vec < Key > = add_data. iter ( ) . map ( |( k, _) | * k) . collect ( ) ;
182+
183+ match self . add_batch ( add_data) {
184+ Ok ( ( ) ) => keys. into_iter ( ) . map ( |key| ( key, Ok ( ( ) ) ) ) . collect ( ) ,
185+ Err ( e) => {
186+ let shared_error: Arc < str > = Error :: from ( e) . to_string ( ) . into ( ) ;
187+ keys. into_iter ( )
188+ . map ( |key| ( key, Err ( Error :: Other ( shared_error. as_ref ( ) . into ( ) ) . into ( ) ) ) )
189+ . collect ( )
190+ }
191+ }
164192 }
165193
166194 fn remove ( & mut self , keys : Vec < Key > ) -> Vec < ( Key , DynResult < ( ) > ) > {
167- keys. into_iter ( )
168- . map ( |key| ( key, self . remove_one ( key) ) )
169- . collect ( )
195+ if keys. is_empty ( ) {
196+ return Vec :: new ( ) ;
197+ }
198+
199+ let original_keys = keys. clone ( ) ;
200+
201+ match self . remove_batch ( keys) {
202+ Ok ( _) => original_keys. into_iter ( ) . map ( |key| ( key, Ok ( ( ) ) ) ) . collect ( ) ,
203+ Err ( e) => {
204+ let shared_error: Arc < str > = Error :: from ( e) . to_string ( ) . into ( ) ;
205+ original_keys
206+ . into_iter ( )
207+ . map ( |key| ( key, Err ( Error :: Other ( shared_error. as_ref ( ) . into ( ) ) . into ( ) ) ) )
208+ . collect ( )
209+ }
210+ }
170211 }
171212
172213 fn search (
173214 & self ,
174215 queries : & [ ( Key , Vec < f64 > , usize ) ] ,
175216 ) -> Vec < ( Key , DynResult < Vec < KeyScoreMatch > > ) > {
176- queries
177- . iter ( )
178- . map ( |( key, data, limit) | ( * key, self . search_one ( data, * limit) ) )
217+ if queries. is_empty ( ) {
218+ return Vec :: new ( ) ;
219+ }
220+
221+ let keys: Vec < Key > = queries. iter ( ) . map ( |( k, _, _) | * k) . collect ( ) ;
222+
223+ let results = self . runtime . block_on ( async {
224+ let mut futures = Vec :: with_capacity ( queries. len ( ) ) ;
225+
226+ for ( _, data, limit) in queries {
227+ futures. push ( self . search_one_async ( data, * limit) ) ;
228+ }
229+
230+ futures:: future:: join_all ( futures) . await
231+ } ) ;
232+
233+ keys. into_iter ( )
234+ . zip ( results)
235+ . map ( |( key, result) | ( key, result. map_err ( |e| Error :: from ( e) . into ( ) ) ) )
179236 . collect ( )
180237 }
181238}
@@ -189,13 +246,13 @@ pub struct QdrantIndexFactory {
189246
190247impl QdrantIndexFactory {
191248 pub fn new (
192- url : String ,
249+ url : & str ,
193250 collection_name : String ,
194251 vector_size : usize ,
195252 api_key : Option < String > ,
196253 ) -> QdrantIndexFactory {
197254 QdrantIndexFactory {
198- url,
255+ url : url . to_string ( ) ,
199256 collection_name,
200257 vector_size,
201258 api_key,
@@ -206,7 +263,7 @@ impl QdrantIndexFactory {
206263impl ExternalIndexFactory for QdrantIndexFactory {
207264 fn make_instance ( & self ) -> Result < Box < dyn ExternalIndex > , Error > {
208265 let qdrant_index = QdrantIndex :: new (
209- self . url . clone ( ) ,
266+ & self . url ,
210267 self . collection_name . clone ( ) ,
211268 self . vector_size ,
212269 self . api_key . clone ( ) ,
0 commit comments