11use std:: cell:: RefCell ;
2+ use std:: collections:: HashSet ;
23use std:: fs:: File ;
34use std:: io:: { BufReader , BufWriter } ;
45use std:: rc:: Rc ;
@@ -12,11 +13,11 @@ use finalfusion::prelude::*;
1213use finalfusion:: similarity:: * ;
1314use itertools:: Itertools ;
1415use ndarray:: Array2 ;
15- use numpy:: { IntoPyArray , PyArray1 , PyArray2 } ;
16+ use numpy:: { IntoPyArray , NpyDataType , PyArray1 , PyArray2 } ;
1617use pyo3:: class:: iter:: PyIterProtocol ;
1718use pyo3:: prelude:: * ;
18- use pyo3:: types:: PyTuple ;
19- use pyo3:: { exceptions, PyMappingProtocol } ;
19+ use pyo3:: types:: { PyAny , PyList , PySet , PyTuple } ;
20+ use pyo3:: { exceptions, PyMappingProtocol , PyTypeInfo } ;
2021use toml:: { self , Value } ;
2122
2223use crate :: { EmbeddingsWrap , PyEmbeddingIterator , PyVocab , PyWordSimilarity } ;
@@ -143,18 +144,7 @@ impl PyEmbeddings {
143144 exceptions:: KeyError :: py_err ( format ! ( "Unknown word or n-grams: {}" , failed) )
144145 } ) ?;
145146
146- let mut r = Vec :: with_capacity ( results. len ( ) ) ;
147- for ws in results {
148- r. push (
149- Py :: new (
150- py,
151- PyWordSimilarity :: new ( ws. word . to_owned ( ) , ws. similarity . into_inner ( ) ) ,
152- ) ?
153- . into_object ( py) ,
154- )
155- }
156-
157- Ok ( r)
147+ Self :: similarity_results ( py, results)
158148 }
159149
160150 /// Get the embedding for the given word.
@@ -258,7 +248,7 @@ impl PyEmbeddings {
258248
259249 /// Perform a similarity query.
260250 #[ args( limit = 10 ) ]
261- fn similarity ( & self , py : Python , word : & str , limit : usize ) -> PyResult < Vec < PyObject > > {
251+ fn word_similarity ( & self , py : Python , word : & str , limit : usize ) -> PyResult < Vec < PyObject > > {
262252 let embeddings = self . embeddings . borrow ( ) ;
263253
264254 let embeddings = embeddings. view ( ) . ok_or_else ( || {
@@ -271,18 +261,46 @@ impl PyEmbeddings {
271261 . word_similarity ( word, limit)
272262 . ok_or_else ( || exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ?;
273263
274- let mut r = Vec :: with_capacity ( results. len ( ) ) ;
275- for ws in results {
276- r. push (
277- Py :: new (
278- py,
279- PyWordSimilarity :: new ( ws. word . to_owned ( ) , ws. similarity . into_inner ( ) ) ,
280- ) ?
281- . into_object ( py) ,
264+ Self :: similarity_results ( py, results)
265+ }
266+
267+ /// Perform a similarity query based on a query embedding.
268+ #[ args( limit = 10 , skip = "None" ) ]
269+ fn embedding_similarity (
270+ & self ,
271+ py : Python ,
272+ embedding : PyEmbedding ,
273+ skip : Option < Option < Skips > > ,
274+ limit : usize ,
275+ ) -> PyResult < Vec < PyObject > > {
276+ let embeddings = self . embeddings . borrow ( ) ;
277+
278+ let embeddings = embeddings. view ( ) . ok_or_else ( || {
279+ exceptions:: ValueError :: py_err (
280+ "Similarity queries are not supported for this type of embedding matrix" ,
282281 )
282+ } ) ?;
283+
284+ let embedding = embedding. 0 . as_array ( ) ;
285+
286+ if embedding. shape ( ) [ 0 ] != embeddings. storage ( ) . shape ( ) . 1 {
287+ return Err ( exceptions:: ValueError :: py_err ( format ! (
288+ "Incompatible embedding shapes: embeddings: ({},), query: ({},)" ,
289+ embedding. shape( ) [ 0 ] ,
290+ embeddings. storage( ) . shape( ) . 1
291+ ) ) ) ;
283292 }
284293
285- Ok ( r)
294+ let results = if let Some ( Some ( skip) ) = skip {
295+ embeddings. embedding_similarity_masked ( embedding, limit, & skip. 0 )
296+ } else {
297+ embeddings. embedding_similarity ( embedding, limit)
298+ } ;
299+
300+ Self :: similarity_results (
301+ py,
302+ results. ok_or_else ( || exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ?,
303+ )
286304 }
287305
288306 /// Write the embeddings to a finalfusion file.
@@ -304,6 +322,25 @@ impl PyEmbeddings {
304322 }
305323}
306324
325+ impl PyEmbeddings {
326+ fn similarity_results (
327+ py : Python ,
328+ results : Vec < WordSimilarityResult > ,
329+ ) -> PyResult < Vec < PyObject > > {
330+ let mut r = Vec :: with_capacity ( results. len ( ) ) ;
331+ for ws in results {
332+ r. push (
333+ Py :: new (
334+ py,
335+ PyWordSimilarity :: new ( ws. word . to_owned ( ) , ws. similarity . into_inner ( ) ) ,
336+ ) ?
337+ . into_object ( py) ,
338+ )
339+ }
340+ Ok ( r)
341+ }
342+ }
343+
307344#[ pyproto]
308345impl PyMappingProtocol for PyEmbeddings {
309346 fn __getitem__ ( & self , word : & str ) -> PyResult < Py < PyArray1 < f32 > > > {
@@ -372,3 +409,47 @@ where
372409 embeddings : Rc :: new ( RefCell :: new ( EmbeddingsWrap :: View ( embeddings. into ( ) ) ) ) ,
373410 } )
374411}
412+
413+ struct Skips < ' a > ( HashSet < & ' a str > ) ;
414+
415+ impl < ' a > FromPyObject < ' a > for Skips < ' a > {
416+ fn extract ( ob : & ' a PyAny ) -> Result < Self , PyErr > {
417+ let mut set = ob
418+ . len ( )
419+ . map ( |len| HashSet :: with_capacity ( len) )
420+ . unwrap_or_default ( ) ;
421+
422+ let iter = if <PySet as PyTypeInfo >:: is_instance ( ob) {
423+ ob. iter ( ) . unwrap ( )
424+ } else if <PyList as PyTypeInfo >:: is_instance ( ob) {
425+ ob. iter ( ) . unwrap ( )
426+ } else {
427+ return Err ( exceptions:: TypeError :: py_err ( "Iterable expected" ) ) ;
428+ } ;
429+
430+ for el in iter {
431+ set. insert (
432+ el?. extract ( )
433+ . map_err ( |_| exceptions:: TypeError :: py_err ( "Expected String" ) ) ?,
434+ ) ;
435+ }
436+ Ok ( Skips ( set) )
437+ }
438+ }
439+
440+ struct PyEmbedding < ' a > ( & ' a PyArray1 < f32 > ) ;
441+
442+ impl < ' a > FromPyObject < ' a > for PyEmbedding < ' a > {
443+ fn extract ( ob : & ' a PyAny ) -> Result < Self , PyErr > {
444+ let embedding = ob
445+ . downcast_ref :: < PyArray1 < f32 > > ( )
446+ . map_err ( |_| exceptions:: TypeError :: py_err ( "Expected array with dtype Float32" ) ) ?;
447+ if embedding. data_type ( ) != NpyDataType :: Float32 {
448+ return Err ( exceptions:: TypeError :: py_err ( format ! (
449+ "Expected dtype Float32, got {:?}" ,
450+ embedding. data_type( )
451+ ) ) ) ;
452+ } ;
453+ Ok ( PyEmbedding ( embedding) )
454+ }
455+ }
0 commit comments