File tree Expand file tree Collapse file tree 1 file changed +6
-7
lines changed
Expand file tree Collapse file tree 1 file changed +6
-7
lines changed Original file line number Diff line number Diff line change @@ -310,12 +310,12 @@ impl PyEmbeddings {
310310 }
311311
312312 /// Perform a similarity query based on a query embedding.
313- #[ args( limit = 10 , skip = "None " ) ]
313+ #[ args( limit = 10 , skip = "Skips(HashSet::new()) " ) ]
314314 fn embedding_similarity (
315315 & self ,
316316 py : Python ,
317317 embedding : PyEmbedding ,
318- skip : Option < Option < Skips > > ,
318+ skip : Skips ,
319319 limit : usize ,
320320 ) -> PyResult < Vec < PyObject > > {
321321 let embeddings = self . embeddings . borrow ( ) ;
@@ -336,11 +336,7 @@ impl PyEmbeddings {
336336 ) ) ) ;
337337 }
338338
339- let results = if let Some ( Some ( skip) ) = skip {
340- embeddings. embedding_similarity_masked ( embedding, limit, & skip. 0 )
341- } else {
342- embeddings. embedding_similarity ( embedding, limit)
343- } ;
339+ let results = embeddings. embedding_similarity_masked ( embedding, limit, & skip. 0 ) ;
344340
345341 Self :: similarity_results (
346342 py,
@@ -462,6 +458,9 @@ struct Skips<'a>(HashSet<&'a str>);
462458impl < ' a > FromPyObject < ' a > for Skips < ' a > {
463459 fn extract ( ob : & ' a PyAny ) -> Result < Self , PyErr > {
464460 let mut set = ob. len ( ) . map ( HashSet :: with_capacity) . unwrap_or_default ( ) ;
461+ if ob. is_none ( ) {
462+ return Ok ( Skips ( set) ) ;
463+ }
465464 for el in ob
466465 . iter ( )
467466 . map_err ( |_| exceptions:: TypeError :: py_err ( "Iterable expected" ) ) ?
You can’t perform that action at this time.
0 commit comments