@@ -4,9 +4,11 @@ use std::rc::Rc;
44use finalfusion:: chunks:: vocab:: { NGramIndices , SubwordIndices , VocabWrap , WordIndex } ;
55use finalfusion:: prelude:: * ;
66use pyo3:: class:: sequence:: PySequenceProtocol ;
7- use pyo3:: exceptions;
7+ use pyo3:: exceptions:: { IndexError , KeyError , ValueError } ;
88use pyo3:: prelude:: * ;
9+ use pyo3:: { PyIterProtocol , PyMappingProtocol } ;
910
11+ use crate :: iter:: PyVocabIterator ;
1012use crate :: EmbeddingsWrap ;
1113
1214type NGramIndex = ( String , Option < usize > ) ;
@@ -25,16 +27,18 @@ impl PyVocab {
2527
2628#[ pymethods]
2729impl PyVocab {
28- fn item_to_indices ( & self , key : String ) -> Option < PyObject > {
30+ #[ args( default = "Python::acquire_gil().python().None()" ) ]
31+ fn get ( & self , key : & str , default : PyObject ) -> Option < PyObject > {
2932 let embeds = self . embeddings . borrow ( ) ;
30-
31- embeds. vocab ( ) . idx ( key. as_str ( ) ) . map ( |idx| {
32- let gil = pyo3:: Python :: acquire_gil ( ) ;
33- match idx {
34- WordIndex :: Word ( idx) => [ idx] . to_object ( gil. python ( ) ) ,
35- WordIndex :: Subword ( indices) => indices. to_object ( gil. python ( ) ) ,
36- }
37- } )
33+ let gil = pyo3:: Python :: acquire_gil ( ) ;
34+ let idx = embeds. vocab ( ) . idx ( key) . map ( |idx| match idx {
35+ WordIndex :: Word ( idx) => idx. to_object ( gil. python ( ) ) ,
36+ WordIndex :: Subword ( indices) => indices. to_object ( gil. python ( ) ) ,
37+ } ) ;
38+ if !default. is_none ( ) && idx. is_none ( ) {
39+ return Some ( default) ;
40+ }
41+ idx
3842 }
3943
4044 fn ngram_indices ( & self , word : & str ) -> PyResult < Option < Vec < NGramIndex > > > {
@@ -44,7 +48,7 @@ impl PyVocab {
4448 VocabWrap :: FinalfusionSubwordVocab ( inner) => inner. ngram_indices ( word) ,
4549 VocabWrap :: FinalfusionNGramVocab ( inner) => inner. ngram_indices ( word) ,
4650 VocabWrap :: SimpleVocab ( _) => {
47- return Err ( exceptions :: ValueError :: py_err (
51+ return Err ( ValueError :: py_err (
4852 "querying n-gram indices is not supported for this vocabulary" ,
4953 ) )
5054 }
@@ -57,29 +61,71 @@ impl PyVocab {
5761 VocabWrap :: FastTextSubwordVocab ( inner) => Ok ( inner. subword_indices ( word) ) ,
5862 VocabWrap :: FinalfusionSubwordVocab ( inner) => Ok ( inner. subword_indices ( word) ) ,
5963 VocabWrap :: FinalfusionNGramVocab ( inner) => Ok ( inner. subword_indices ( word) ) ,
60- VocabWrap :: SimpleVocab ( _) => Err ( exceptions :: ValueError :: py_err (
64+ VocabWrap :: SimpleVocab ( _) => Err ( ValueError :: py_err (
6165 "querying subwords' indices is not supported for this vocabulary" ,
6266 ) ) ,
6367 }
6468 }
6569}
6670
67- #[ pyproto]
68- impl PySequenceProtocol for PyVocab {
69- fn __len__ ( & self ) -> PyResult < usize > {
71+ impl PyVocab {
72+ fn str_to_indices ( & self , query : & str ) -> PyResult < WordIndex > {
7073 let embeds = self . embeddings . borrow ( ) ;
71- Ok ( embeds. vocab ( ) . words_len ( ) )
74+ embeds
75+ . vocab ( )
76+ . idx ( query)
77+ . ok_or_else ( || KeyError :: py_err ( format ! ( "key not found: '{}'" , query) ) )
7278 }
7379
74- fn __getitem__ ( & self , idx : isize ) -> PyResult < String > {
80+ fn validate_and_convert_isize_idx ( & self , mut idx : isize ) -> PyResult < usize > {
7581 let embeds = self . embeddings . borrow ( ) ;
76- let words = embeds. vocab ( ) . words ( ) ;
82+ let vocab = embeds. vocab ( ) ;
83+ if idx < 0 {
84+ idx += vocab. words_len ( ) as isize ;
85+ }
7786
78- if idx >= words . len ( ) as isize || idx < 0 {
79- Err ( exceptions :: IndexError :: py_err ( "list index out of range" ) )
87+ if idx >= vocab . words_len ( ) as isize || idx < 0 {
88+ Err ( IndexError :: py_err ( "list index out of range" ) )
8089 } else {
81- Ok ( words[ idx as usize ] . clone ( ) )
90+ Ok ( idx as usize )
91+ }
92+ }
93+ }
94+
95+ #[ pyproto]
96+ impl PyMappingProtocol for PyVocab {
97+ fn __getitem__ ( & self , query : PyObject ) -> PyResult < PyObject > {
98+ let embeds = self . embeddings . borrow ( ) ;
99+ let vocab = embeds. vocab ( ) ;
100+ let gil = Python :: acquire_gil ( ) ;
101+ if let Ok ( idx) = query. extract :: < isize > ( gil. python ( ) ) {
102+ let idx = self . validate_and_convert_isize_idx ( idx) ?;
103+ return Ok ( vocab. words ( ) [ idx] . clone ( ) . into_py ( gil. python ( ) ) ) ;
104+ }
105+
106+ if let Ok ( query) = query. extract :: < String > ( gil. python ( ) ) {
107+ return self . str_to_indices ( & query) . map ( |idx| match idx {
108+ WordIndex :: Subword ( indices) => indices. into_py ( gil. python ( ) ) ,
109+ WordIndex :: Word ( idx) => idx. into_py ( gil. python ( ) ) ,
110+ } ) ;
82111 }
112+
113+ Err ( KeyError :: py_err ( "key must be integer or string" ) )
114+ }
115+ }
116+
117+ #[ pyproto]
118+ impl PyIterProtocol for PyVocab {
119+ fn __iter__ ( slf : PyRefMut < Self > ) -> PyResult < PyVocabIterator > {
120+ Ok ( PyVocabIterator :: new ( slf. embeddings . clone ( ) , 0 ) )
121+ }
122+ }
123+
124+ #[ pyproto]
125+ impl PySequenceProtocol for PyVocab {
126+ fn __len__ ( & self ) -> PyResult < usize > {
127+ let embeds = self . embeddings . borrow ( ) ;
128+ Ok ( embeds. vocab ( ) . words_len ( ) )
83129 }
84130
85131 fn __contains__ ( & self , word : String ) -> PyResult < bool > {
0 commit comments