@@ -9,9 +9,10 @@ use failure::Error;
99use finalfusion:: metadata:: Metadata ;
1010use finalfusion:: prelude:: * ;
1111use finalfusion:: similarity:: * ;
12+ use finalfusion:: vocab:: WordIndex ;
1213use ndarray:: Array2 ;
1314use numpy:: { IntoPyArray , PyArray1 , PyArray2 } ;
14- use pyo3:: class:: { basic:: PyObjectProtocol , iter:: PyIterProtocol } ;
15+ use pyo3:: class:: { basic:: PyObjectProtocol , iter:: PyIterProtocol , sequence :: PySequenceProtocol } ;
1516use pyo3:: exceptions;
1617use pyo3:: prelude:: * ;
1718use toml:: { self , Value } ;
@@ -24,6 +25,7 @@ use toml::{self, Value};
2425fn finalfusion ( _py : Python , m : & PyModule ) -> PyResult < ( ) > {
2526 m. add_class :: < PyEmbeddings > ( ) ?;
2627 m. add_class :: < PyWordSimilarity > ( ) ?;
28+ m. add_class :: < PyVocab > ( ) ?;
2729 Ok ( ( ) )
2830}
2931
@@ -97,6 +99,13 @@ impl PyEmbeddings {
9799 Ok ( ( ) )
98100 }
99101
102+ /// Get the model's vocabulary.
103+ fn vocab ( & self ) -> PyResult < PyVocab > {
104+ Ok ( PyVocab {
105+ embeddings : self . embeddings . clone ( ) ,
106+ } )
107+ }
108+
100109 /// Perform an anology query.
101110 ///
102111 /// This returns words for the analogy query *w1* is to *w2*
@@ -309,6 +318,64 @@ impl PyIterProtocol for PyEmbeddings {
309318 }
310319}
311320
321+ /// finalfusion vocab.
322+ #[ pyclass( name=Vocab ) ]
323+ struct PyVocab {
324+ embeddings : Rc < RefCell < EmbeddingsWrap > > ,
325+ }
326+
327+ #[ pymethods]
328+ impl PyVocab {
329+ fn item_to_indices ( & self , key : String ) -> PyResult < PyObject > {
330+ let embeds = self . embeddings . borrow ( ) ;
331+
332+ use EmbeddingsWrap :: * ;
333+ let indices = match & * embeds {
334+ View ( e) => e. vocab ( ) . idx ( key. as_str ( ) ) ,
335+ NonView ( e) => e. vocab ( ) . idx ( key. as_str ( ) ) ,
336+ } ;
337+ indices
338+ . map ( |idx| {
339+ let gil = pyo3:: Python :: acquire_gil ( ) ;
340+ match idx {
341+ WordIndex :: Word ( idx) => [ idx] . to_object ( gil. python ( ) ) ,
342+ WordIndex :: Subword ( indices) => indices. to_object ( gil. python ( ) ) ,
343+ }
344+ } )
345+ . ok_or_else ( || exceptions:: KeyError :: py_err ( "Unknown word or n-grams" ) )
346+ }
347+ }
348+
349+ #[ pyproto]
350+ impl PySequenceProtocol for PyVocab {
351+ fn __len__ ( & self ) -> PyResult < usize > {
352+ let embeds = self . embeddings . borrow ( ) ;
353+
354+ use EmbeddingsWrap :: * ;
355+ match & * embeds {
356+ View ( e) => Ok ( e. vocab ( ) . words ( ) . len ( ) ) ,
357+ NonView ( e) => Ok ( e. vocab ( ) . words ( ) . len ( ) ) ,
358+ }
359+ }
360+
361+ fn __getitem__ ( & self , idx : isize ) -> PyResult < String > {
362+ let embeds = self . embeddings . borrow ( ) ;
363+
364+ use EmbeddingsWrap :: * ;
365+
366+ let v = match & * embeds {
367+ View ( e) => e. vocab ( ) ,
368+ NonView ( e) => e. vocab ( ) ,
369+ } ;
370+
371+ if idx >= v. words ( ) . len ( ) as isize || idx < 0 {
372+ Err ( exceptions:: IndexError :: py_err ( "list index out of range" ) )
373+ } else {
374+ Ok ( v. words ( ) [ idx as usize ] . clone ( ) )
375+ }
376+ }
377+ }
378+
312379fn load_embeddings < S > ( path : & str , mmap : bool ) -> Result < Embeddings < VocabWrap , S > , Error >
313380where
314381 Embeddings < VocabWrap , S > : ReadEmbeddings + MmapEmbeddings ,
0 commit comments