@@ -11,18 +11,19 @@ use finalfusion::io as ffio;
1111use finalfusion:: prelude:: * ;
1212use finalfusion:: similarity:: * ;
1313use itertools:: Itertools ;
14+ use ndarray:: Array1 ;
1415use numpy:: { IntoPyArray , NpyDataType , PyArray1 } ;
1516use pyo3:: class:: iter:: PyIterProtocol ;
1617use pyo3:: prelude:: * ;
17- use pyo3:: types:: { PyAny , PyTuple } ;
18+ use pyo3:: types:: { PyAny , PyIterator , PyTuple } ;
1819use pyo3:: { exceptions, PyMappingProtocol } ;
1920use toml:: { self , Value } ;
2021
2122use crate :: storage:: PyStorage ;
2223use crate :: { EmbeddingsWrap , PyEmbeddingIterator , PyVocab , PyWordSimilarity } ;
2324
2425/// finalfusion embeddings.
25- #[ pyclass( name= Embeddings ) ]
26+ #[ pyclass( name = Embeddings ) ]
2627pub struct PyEmbeddings {
2728 // The use of Rc + RefCell should be safe in this crate:
2829 //
@@ -179,17 +180,49 @@ impl PyEmbeddings {
179180 Self :: similarity_results ( py, results)
180181 }
181182
183+ /// embedding(word,/, default)
184+ /// --
185+ ///
182186 /// Get the embedding for the given word.
183187 ///
184188 /// If the word is not known, its representation is approximated
185- /// using subword units.
186- fn embedding ( & self , word : & str ) -> Option < Py < PyArray1 < f32 > > > {
189+ /// using subword units. #
190+ ///
191+ /// If no representation can be calculated:
192+ /// - `None` if `default` is `None`
193+ /// - an array filled with `default` if `default` is a scalar
194+ /// - an array if `default` is a 1-d array
195+ /// - an array filled with values from `default` if it is an iterator over floats.
196+ #[ args( default = "PyEmbeddingDefault::default()" ) ]
197+ fn embedding (
198+ & self ,
199+ word : & str ,
200+ default : PyEmbeddingDefault ,
201+ ) -> PyResult < Option < Py < PyArray1 < f32 > > > > {
187202 let embeddings = self . embeddings . borrow ( ) ;
203+ let gil = pyo3:: Python :: acquire_gil ( ) ;
204+ if let PyEmbeddingDefault :: Embedding ( array) = & default {
205+ if array. as_ref ( gil. python ( ) ) . shape ( ) [ 0 ] != embeddings. storage ( ) . shape ( ) . 1 {
206+ return Err ( exceptions:: ValueError :: py_err ( format ! (
207+ "Invalid shape of default embedding: {}" ,
208+ array. as_ref( gil. python( ) ) . shape( ) [ 0 ]
209+ ) ) ) ;
210+ }
211+ }
188212
189- embeddings. embedding ( word) . map ( |e| {
190- let gil = pyo3:: Python :: acquire_gil ( ) ;
191- e. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( )
192- } )
213+ if let Some ( embedding) = embeddings. embedding ( word) {
214+ return Ok ( Some (
215+ embedding. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( ) ,
216+ ) ) ;
217+ } ;
218+ match default {
219+ PyEmbeddingDefault :: Constant ( constant) => {
220+ let nd_arr = Array1 :: from_elem ( [ embeddings. storage ( ) . shape ( ) . 1 ] , constant) ;
221+ Ok ( Some ( nd_arr. into_pyarray ( gil. python ( ) ) . to_owned ( ) ) )
222+ }
223+ PyEmbeddingDefault :: Embedding ( array) => Ok ( Some ( array) ) ,
224+ PyEmbeddingDefault :: None => Ok ( None ) ,
225+ }
193226 }
194227
195228 fn embedding_with_norm ( & self , word : & str ) -> Option < Py < PyTuple > > {
@@ -415,6 +448,58 @@ where
415448 } )
416449}
417450
451+ pub enum PyEmbeddingDefault {
452+ Embedding ( Py < PyArray1 < f32 > > ) ,
453+ Constant ( f32 ) ,
454+ None ,
455+ }
456+
457+ impl < ' a > Default for PyEmbeddingDefault {
458+ fn default ( ) -> Self {
459+ PyEmbeddingDefault :: None
460+ }
461+ }
462+
463+ impl < ' a > FromPyObject < ' a > for PyEmbeddingDefault {
464+ fn extract ( ob : & ' a PyAny ) -> Result < Self , PyErr > {
465+ if ob. is_none ( ) {
466+ return Ok ( PyEmbeddingDefault :: None ) ;
467+ }
468+ if let Ok ( emb) = ob
469+ . extract ( )
470+ . map ( |e : & PyArray1 < f32 > | PyEmbeddingDefault :: Embedding ( e. to_owned ( ) ) )
471+ {
472+ return Ok ( emb) ;
473+ }
474+
475+ if let Ok ( constant) = ob. extract ( ) . map ( PyEmbeddingDefault :: Constant ) {
476+ return Ok ( constant) ;
477+ }
478+ if let Ok ( embed) = ob
479+ . iter ( )
480+ . and_then ( |iter| collect_array_from_py_iter ( iter, ob. len ( ) . ok ( ) ) )
481+ . map ( PyEmbeddingDefault :: Embedding )
482+ {
483+ return Ok ( embed) ;
484+ }
485+
486+ Err ( exceptions:: TypeError :: py_err (
487+ "failed to construct default value." ,
488+ ) )
489+ }
490+ }
491+
492+ fn collect_array_from_py_iter ( iter : PyIterator , len : Option < usize > ) -> PyResult < Py < PyArray1 < f32 > > > {
493+ let mut embed_vec = len. map ( Vec :: with_capacity) . unwrap_or_default ( ) ;
494+ for item in iter {
495+ let item = item. and_then ( |item| item. extract ( ) ) ?;
496+ embed_vec. push ( item) ;
497+ }
498+ let gil = Python :: acquire_gil ( ) ;
499+ let embed = PyArray1 :: from_vec ( gil. python ( ) , embed_vec) . to_owned ( ) ;
500+ Ok ( embed)
501+ }
502+
418503struct Skips < ' a > ( HashSet < & ' a str > ) ;
419504
420505impl < ' a > FromPyObject < ' a > for Skips < ' a > {
0 commit comments