22
33extern crate failure;
44extern crate finalfrontier;
5+ extern crate ndarray;
56extern crate pyo3;
67
78use std:: fs:: File ;
89use std:: io:: BufReader ;
10+ use std:: rc:: Rc ;
911
1012use failure:: Error ;
1113use finalfrontier:: similarity:: { Analogy , Similarity } ;
12- use finalfrontier:: { MmapModelBinary , ReadModelBinary } ;
14+ use finalfrontier:: { MmapModelBinary , Model , ReadModelBinary } ;
15+ use ndarray:: Axis ;
1316use pyo3:: prelude:: * ;
1417
1518/// This is a binding for finalfrontier.
@@ -57,7 +60,7 @@ impl PyObjectProtocol for PyWordSimilarity {
5760/// A finalfrontier model.
5861#[ pyclass( name=Model ) ]
5962struct PyModel {
60- model : finalfrontier :: Model ,
63+ model : Rc < Model > ,
6164 token : PyToken ,
6265}
6366
@@ -73,9 +76,9 @@ impl PyModel {
7376 #[ args( mmap = false ) ]
7477 fn __new__ ( obj : & PyRawObject , path : & str , mmap : bool ) -> PyResult < ( ) > {
7578 let model = match load_model ( path, mmap) {
76- Ok ( model) => model,
79+ Ok ( model) => Rc :: new ( model) ,
7780 Err ( err) => {
78- return Err ( exc:: IOError :: new ( err. to_string ( ) ) ) ;
81+ return Err ( exc:: IOError :: py_err ( err. to_string ( ) ) ) ;
7982 }
8083 } ;
8184
@@ -97,7 +100,7 @@ impl PyModel {
97100 ) -> PyResult < Vec < PyObject > > {
98101 let results = match self . model . analogy ( word1, word2, word3, limit) {
99102 Some ( results) => results,
100- None => return Err ( exc:: KeyError :: new ( "Unknown word and n-grams" ) ) ,
103+ None => return Err ( exc:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
101104 } ;
102105
103106 let mut r = Vec :: with_capacity ( results. len ( ) ) ;
@@ -121,7 +124,7 @@ impl PyModel {
121124 fn embedding ( & self , word : & str ) -> PyResult < Vec < f32 > > {
122125 match self . model . embedding ( word) {
123126 Some ( embedding) => Ok ( embedding. to_vec ( ) ) ,
124- None => Err ( exc:: KeyError :: new ( "Unknown word and n-grams" ) ) ,
127+ None => Err ( exc:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
125128 }
126129 }
127130
@@ -130,7 +133,7 @@ impl PyModel {
130133 fn similarity ( & self , py : Python , word : & str , limit : usize ) -> PyResult < Vec < PyObject > > {
131134 let results = match self . model . similarity ( word, limit) {
132135 Some ( results) => results,
133- None => return Err ( exc:: KeyError :: new ( "Unknown word and n-grams" ) ) ,
136+ None => return Err ( exc:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
134137 } ;
135138
136139 let mut r = Vec :: with_capacity ( results. len ( ) ) ;
@@ -148,14 +151,57 @@ impl PyModel {
148151 }
149152}
150153
151- fn load_model ( path : & str , mmap : bool ) -> Result < finalfrontier:: Model , Error > {
154+ #[ pyproto]
155+ impl PyIterProtocol for PyModel {
156+ fn __iter__ ( & mut self ) -> PyResult < PyObject > {
157+ let gil = Python :: acquire_gil ( ) ;
158+ let py = gil. python ( ) ;
159+ let iter = Py :: new ( py, |token| PyModelIterator {
160+ model : self . model . clone ( ) ,
161+ idx : 0 ,
162+ token,
163+ } ) ?. into_object ( py) ;
164+
165+ Ok ( iter)
166+ }
167+ }
168+
169+ fn load_model ( path : & str , mmap : bool ) -> Result < Model , Error > {
152170 let f = File :: open ( path) ?;
153171
154172 let model = if mmap {
155- finalfrontier :: Model :: mmap_model_binary ( f) ?
173+ Model :: mmap_model_binary ( f) ?
156174 } else {
157- finalfrontier :: Model :: read_model_binary ( & mut BufReader :: new ( f) ) ?
175+ Model :: read_model_binary ( & mut BufReader :: new ( f) ) ?
158176 } ;
159177
160178 Ok ( model)
161179}
180+
181+ #[ pyclass( name=ModelIterator ) ]
182+ struct PyModelIterator {
183+ model : Rc < Model > ,
184+ idx : usize ,
185+ token : PyToken ,
186+ }
187+
188+ #[ pyproto]
189+ impl PyIterProtocol for PyModelIterator {
190+ fn __iter__ ( & mut self ) -> PyResult < PyObject > {
191+ Ok ( self . into ( ) )
192+ }
193+
194+ fn __next__ ( & mut self ) -> PyResult < Option < ( String , Vec < f32 > ) > > {
195+ let vocab = self . model . vocab ( ) ;
196+ let embeddings = self . model . embedding_matrix ( ) ;
197+
198+ if self . idx < vocab. len ( ) {
199+ let word = vocab. words ( ) [ self . idx ] . word ( ) . to_string ( ) ;
200+ let embed = embeddings. subview ( Axis ( 0 ) , self . idx ) . to_vec ( ) ;
201+ self . idx += 1 ;
202+ Ok ( Some ( ( word, embed) ) )
203+ } else {
204+ Ok ( None )
205+ }
206+ }
207+ }
0 commit comments