11#![ feature( specialization) ]
22
3+ use std:: cell:: RefCell ;
34use std:: fs:: File ;
45use std:: io:: BufReader ;
56use std:: rc:: Rc ;
67
78use failure:: Error ;
8- use finalfrontier:: similarity:: { Analogy , Similarity } ;
9- use finalfrontier:: { MmapModelBinary , Model , ReadModelBinary } ;
10- use ndarray:: Axis ;
119use pyo3:: class:: { basic:: PyObjectProtocol , iter:: PyIterProtocol } ;
1210use pyo3:: exceptions;
1311use pyo3:: prelude:: * ;
12+ use rust2vec:: metadata:: Metadata ;
13+ use rust2vec:: prelude:: * ;
14+ use rust2vec:: similarity:: * ;
15+ use toml:: { self , Value } ;
1416
15- /// This is a binding for finalfrontier .
17+ /// This is a Python module for using finalfusion embeddings .
1618///
17- /// finalfrontier is a library and set of programs for training
18- /// word embeddings with subword units. The Python binding can
19- /// be used to query the resulting embeddings and do similarity
20- /// queries.
19+ /// finalfusion is a format for word embeddings that supports words,
20+ /// subwords, memory-mapped matrices, and quantized matrices.
2121#[ pymodinit]
22- fn finalfrontier ( _py : Python , m : & PyModule ) -> PyResult < ( ) > {
23- m. add_class :: < PyModel > ( ) ?;
22+ fn finalfusion ( _py : Python , m : & PyModule ) -> PyResult < ( ) > {
23+ m. add_class :: < PyEmbeddings > ( ) ?;
2424 m. add_class :: < PyWordSimilarity > ( ) ?;
2525 Ok ( ( ) )
2626}
@@ -54,32 +54,46 @@ impl PyObjectProtocol for PyWordSimilarity {
5454 }
5555}
5656
57- /// A finalfrontier model.
58- #[ pyclass( name=Model ) ]
59- struct PyModel {
60- model : Rc < Model > ,
57+ enum EmbeddingsWrap {
58+ NonView ( Embeddings < VocabWrap , StorageWrap > ) ,
59+ View ( Embeddings < VocabWrap , StorageViewWrap > ) ,
60+ }
61+
62+ /// finalfusion embeddings.
63+ #[ pyclass( name=Embeddings ) ]
64+ struct PyEmbeddings {
65+ // The use of Rc + RefCell should be safe in this crate:
66+ //
67+ // 1. Python's GIL lets only one thread run Python code at a time.
68+ // 2. The only mutable borrow (in set_metadata) is limited
69+ // to its method scope.
70+ // 3. None of the methods returns borrowed embeddings.
71+ embeddings : Rc < RefCell < EmbeddingsWrap > > ,
6172 token : PyToken ,
6273}
6374
6475#[ pymethods]
65- impl PyModel {
66- /// Load a model from the given `path`.
76+ impl PyEmbeddings {
77+ /// Load embeddings from the given `path`.
6778 ///
6879 /// When the `mmap` argument is `True`, the embedding matrix is
6980 /// not loaded into memory, but memory mapped. This results in
70- /// lower memory use and shorter model load times, while sacrificing
81+ /// lower memory use and shorter load times, while sacrificing
7182 /// some query efficiency.
7283 #[ new]
7384 #[ args( mmap = false ) ]
7485 fn __new__ ( obj : & PyRawObject , path : & str , mmap : bool ) -> PyResult < ( ) > {
75- let model = match load_model ( path, mmap) {
76- Ok ( model) => Rc :: new ( model) ,
77- Err ( err) => {
78- return Err ( exceptions:: IOError :: py_err ( err. to_string ( ) ) ) ;
79- }
86+ // First try to load embeddings with viewable storage. If that
87+ // fails, attempt to load the embeddings as non-viewable
88+ // storage.
89+ let embeddings = match load_embeddings ( path, mmap) {
90+ Ok ( e) => Rc :: new ( RefCell :: new ( EmbeddingsWrap :: View ( e) ) ) ,
91+ Err ( _) => load_embeddings ( path, mmap)
92+ . map ( |e| Rc :: new ( RefCell :: new ( EmbeddingsWrap :: NonView ( e) ) ) )
93+ . map_err ( |err| exceptions:: IOError :: py_err ( err. to_string ( ) ) ) ?,
8094 } ;
8195
82- obj. init ( |token| PyModel { model , token } )
96+ obj. init ( |token| PyEmbeddings { embeddings , token } )
8397 }
8498
8599 /// Perform an analogy query.
@@ -95,9 +109,20 @@ impl PyModel {
95109 word3 : & str ,
96110 limit : usize ,
97111 ) -> PyResult < Vec < PyObject > > {
98- let results = match self . model . analogy ( word1, word2, word3, limit) {
112+ use EmbeddingsWrap :: * ;
113+ let embeddings = self . embeddings . borrow ( ) ;
114+ let embeddings = match & * embeddings {
115+ View ( e) => e,
116+ NonView ( _) => {
117+ return Err ( exceptions:: ValueError :: py_err (
118+ "Analogy queries are not supported for this type of embedding matrix" ,
119+ ) ) ;
120+ }
121+ } ;
122+
123+ let results = match embeddings. analogy ( word1, word2, word3, limit) {
99124 Some ( results) => results,
100- None => return Err ( exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
125+ None => return Err ( exceptions:: KeyError :: py_err ( "Unknown word or n-grams" ) ) ,
101126 } ;
102127
103128 let mut r = Vec :: with_capacity ( results. len ( ) ) ;
@@ -120,16 +145,80 @@ impl PyModel {
120145 /// If the word is not known, its representation is approximated
121146 /// using subword units.
122147 fn embedding ( & self , word : & str ) -> PyResult < Vec < f32 > > {
123- match self . model . embedding ( word) {
124- Some ( embedding) => Ok ( embedding. to_vec ( ) ) ,
148+ let embeddings = self . embeddings . borrow ( ) ;
149+
150+ use EmbeddingsWrap :: * ;
151+ let embedding = match & * embeddings {
152+ View ( e) => e. embedding ( word) ,
153+ NonView ( e) => e. embedding ( word) ,
154+ } ;
155+
156+ match embedding {
157+ Some ( embedding) => Ok ( embedding. as_view ( ) . to_vec ( ) ) ,
125158 None => Err ( exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
126159 }
127160 }
128161
162+ /// Embeddings metadata.
163+ #[ getter]
164+ fn metadata ( & self ) -> PyResult < Option < String > > {
165+ let embeddings = self . embeddings . borrow ( ) ;
166+
167+ use EmbeddingsWrap :: * ;
168+ let metadata = match & * embeddings {
169+ View ( e) => e. metadata ( ) ,
170+ NonView ( e) => e. metadata ( ) ,
171+ } ;
172+
173+ match metadata. map ( |v| toml:: ser:: to_string_pretty ( & v. 0 ) ) {
174+ Some ( Ok ( toml) ) => Ok ( Some ( toml) ) ,
175+ Some ( Err ( err) ) => Err ( exceptions:: IOError :: py_err ( format ! (
176+ "Metadata is invalid TOML: {}" ,
177+ err
178+ ) ) ) ,
179+ None => Ok ( None ) ,
180+ }
181+ }
182+
183+ #[ setter]
184+ fn set_metadata ( & mut self , metadata : & str ) -> PyResult < ( ) > {
185+ let value = match metadata. parse :: < Value > ( ) {
186+ Ok ( value) => value,
187+ Err ( err) => {
188+ return Err ( exceptions:: ValueError :: py_err ( format ! (
189+ "Metadata is invalid TOML: {}" ,
190+ err
191+ ) ) ) ;
192+ }
193+ } ;
194+
195+ let mut embeddings = self . embeddings . borrow_mut ( ) ;
196+
197+ use EmbeddingsWrap :: * ;
198+ match & mut * embeddings {
199+ View ( e) => e. set_metadata ( Some ( Metadata ( value) ) ) ,
200+ NonView ( e) => e. set_metadata ( Some ( Metadata ( value) ) ) ,
201+ } ;
202+
203+ Ok ( ( ) )
204+ }
205+
129206 /// Perform a similarity query.
130207 #[ args( limit = 10 ) ]
131208 fn similarity ( & self , py : Python , word : & str , limit : usize ) -> PyResult < Vec < PyObject > > {
132- let results = match self . model . similarity ( word, limit) {
209+ let embeddings = self . embeddings . borrow ( ) ;
210+
211+ use EmbeddingsWrap :: * ;
212+ let embeddings = match & * embeddings {
213+ View ( e) => e,
214+ NonView ( _) => {
215+ return Err ( exceptions:: ValueError :: py_err (
216+ "Similarity queries are not supported for this type of embedding matrix" ,
217+ ) ) ;
218+ }
219+ } ;
220+
221+ let results = match embeddings. similarity ( word, limit) {
133222 Some ( results) => results,
134223 None => return Err ( exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
135224 } ;
@@ -151,12 +240,12 @@ impl PyModel {
151240}
152241
153242#[ pyproto]
154- impl PyIterProtocol for PyModel {
243+ impl PyIterProtocol for PyEmbeddings {
155244 fn __iter__ ( & mut self ) -> PyResult < PyObject > {
156245 let gil = Python :: acquire_gil ( ) ;
157246 let py = gil. python ( ) ;
158- let iter = Py :: new ( py, |token| PyModelIterator {
159- model : self . model . clone ( ) ,
247+ let iter = Py :: new ( py, |token| PyEmbeddingIterator {
248+ embeddings : self . embeddings . clone ( ) ,
160249 idx : 0 ,
161250 token,
162251 } ) ?
@@ -166,40 +255,55 @@ impl PyIterProtocol for PyModel {
166255 }
167256}
168257
169- fn load_model ( path : & str , mmap : bool ) -> Result < Model , Error > {
258+ fn load_embeddings < S > ( path : & str , mmap : bool ) -> Result < Embeddings < VocabWrap , S > , Error >
259+ where
260+ Embeddings < VocabWrap , S > : ReadEmbeddings + MmapEmbeddings ,
261+ {
170262 let f = File :: open ( path) ?;
263+ let mut reader = BufReader :: new ( f) ;
171264
172- let model = if mmap {
173- Model :: mmap_model_binary ( f ) ?
265+ let embeddings = if mmap {
266+ Embeddings :: mmap_embeddings ( & mut reader ) ?
174267 } else {
175- Model :: read_model_binary ( & mut BufReader :: new ( f ) ) ?
268+ Embeddings :: read_embeddings ( & mut reader ) ?
176269 } ;
177270
178- Ok ( model )
271+ Ok ( embeddings )
179272}
180273
181- #[ pyclass( name=ModelIterator ) ]
182- struct PyModelIterator {
183- model : Rc < Model > ,
274+ #[ pyclass( name=EmbeddingIterator ) ]
275+ struct PyEmbeddingIterator {
276+ embeddings : Rc < RefCell < EmbeddingsWrap > > ,
184277 idx : usize ,
185278 token : PyToken ,
186279}
187280
188281#[ pyproto]
189- impl PyIterProtocol for PyModelIterator {
282+ impl PyIterProtocol for PyEmbeddingIterator {
190283 fn __iter__ ( & mut self ) -> PyResult < PyObject > {
191284 Ok ( self . into ( ) )
192285 }
193286
194287 fn __next__ ( & mut self ) -> PyResult < Option < ( String , Vec < f32 > ) > > {
195- let vocab = self . model . vocab ( ) ;
196- let embeddings = self . model . embedding_matrix ( ) ;
288+ let embeddings = self . embeddings . borrow ( ) ;
289+
290+ use EmbeddingsWrap :: * ;
291+ let vocab = match & * embeddings {
292+ View ( e) => e. vocab ( ) ,
293+ NonView ( e) => e. vocab ( ) ,
294+ } ;
197295
198296 if self . idx < vocab. len ( ) {
199- let word = vocab. words ( ) [ self . idx ] . word ( ) . to_string ( ) ;
200- let embed = embeddings. subview ( Axis ( 0 ) , self . idx ) . to_vec ( ) ;
297+ let word = vocab. words ( ) [ self . idx ] . to_string ( ) ;
298+
299+ let embed = match & * embeddings {
300+ View ( e) => e. storage ( ) . embedding ( self . idx ) ,
301+ NonView ( e) => e. storage ( ) . embedding ( self . idx ) ,
302+ } ;
303+
201304 self . idx += 1 ;
202- Ok ( Some ( ( word, embed) ) )
305+
306+ Ok ( Some ( ( word, embed. as_view ( ) . to_vec ( ) ) ) )
203307 } else {
204308 Ok ( None )
205309 }
0 commit comments