@@ -9,6 +9,7 @@ use failure::Error;
99use finalfusion:: metadata:: Metadata ;
1010use finalfusion:: prelude:: * ;
1111use finalfusion:: similarity:: * ;
12+ use numpy:: { IntoPyArray , PyArray1 } ;
1213use pyo3:: class:: { basic:: PyObjectProtocol , iter:: PyIterProtocol } ;
1314use pyo3:: exceptions;
1415use pyo3:: prelude:: * ;
@@ -145,7 +146,7 @@ impl PyEmbeddings {
145146 ///
146147 /// If the word is not known, its representation is approximated
147148 /// using subword units.
148- fn embedding ( & self , word : & str ) -> PyResult < Vec < f32 > > {
149+ fn embedding ( & self , word : & str ) -> PyResult < Py < PyArray1 < f32 > > > {
149150 let embeddings = self . embeddings . borrow ( ) ;
150151
151152 use EmbeddingsWrap :: * ;
@@ -155,7 +156,10 @@ impl PyEmbeddings {
155156 } ;
156157
157158 match embedding {
158- Some ( embedding) => Ok ( embedding. as_view ( ) . to_vec ( ) ) ,
159+ Some ( embedding) => {
160+ let gil = pyo3:: Python :: acquire_gil ( ) ;
161+ Ok ( embedding. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( ) )
162+ }
159163 None => Err ( exceptions:: KeyError :: py_err ( "Unknown word and n-grams" ) ) ,
160164 }
161165 }
@@ -306,7 +310,7 @@ impl PyIterProtocol for PyEmbeddingIterator {
306310 Ok ( slf. into ( ) )
307311 }
308312
309- fn __next__ ( mut slf : PyRefMut < Self > ) -> PyResult < Option < ( String , Vec < f32 > ) > > {
313+ fn __next__ ( mut slf : PyRefMut < Self > ) -> PyResult < Option < ( String , Py < PyArray1 < f32 > > ) > > {
310314 let slf = & mut * slf;
311315
312316 let embeddings = slf. embeddings . borrow ( ) ;
@@ -327,7 +331,11 @@ impl PyIterProtocol for PyEmbeddingIterator {
327331
328332 slf. idx += 1 ;
329333
330- Ok ( Some ( ( word, embed. as_view ( ) . to_vec ( ) ) ) )
334+ let gil = pyo3:: Python :: acquire_gil ( ) ;
335+ Ok ( Some ( (
336+ word,
337+ embed. into_owned ( ) . into_pyarray ( gil. python ( ) ) . to_owned ( ) ,
338+ ) ) )
331339 } else {
332340 Ok ( None )
333341 }
0 commit comments