Skip to content

Commit 430e398

Browse files
danieldkDaniël de Kok
authored andcommitted
Return numpy arrays
1 parent 8075fa0 commit 430e398

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ features = ["extension-module"]
1616
failure = "0.1"
1717
finalfusion = "0.5"
1818
libc = "0.2"
19+
numpy = "0.5"
1920
toml = "0.4"

src/lib.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use failure::Error;
99
use finalfusion::metadata::Metadata;
1010
use finalfusion::prelude::*;
1111
use finalfusion::similarity::*;
12+
use numpy::{IntoPyArray, PyArray1};
1213
use pyo3::class::{basic::PyObjectProtocol, iter::PyIterProtocol};
1314
use pyo3::exceptions;
1415
use pyo3::prelude::*;
@@ -145,7 +146,7 @@ impl PyEmbeddings {
145146
///
146147
/// If the word is not known, its representation is approximated
147148
/// using subword units.
148-
fn embedding(&self, word: &str) -> PyResult<Vec<f32>> {
149+
fn embedding(&self, word: &str) -> PyResult<Py<PyArray1<f32>>> {
149150
let embeddings = self.embeddings.borrow();
150151

151152
use EmbeddingsWrap::*;
@@ -155,7 +156,10 @@ impl PyEmbeddings {
155156
};
156157

157158
match embedding {
158-
Some(embedding) => Ok(embedding.as_view().to_vec()),
159+
Some(embedding) => {
160+
let gil = pyo3::Python::acquire_gil();
161+
Ok(embedding.into_owned().into_pyarray(gil.python()).to_owned())
162+
}
159163
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
160164
}
161165
}
@@ -306,7 +310,7 @@ impl PyIterProtocol for PyEmbeddingIterator {
306310
Ok(slf.into())
307311
}
308312

309-
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Vec<f32>)>> {
313+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Py<PyArray1<f32>>)>> {
310314
let slf = &mut *slf;
311315

312316
let embeddings = slf.embeddings.borrow();
@@ -327,7 +331,11 @@ impl PyIterProtocol for PyEmbeddingIterator {
327331

328332
slf.idx += 1;
329333

330-
Ok(Some((word, embed.as_view().to_vec())))
334+
let gil = pyo3::Python::acquire_gil();
335+
Ok(Some((
336+
word,
337+
embed.into_owned().into_pyarray(gil.python()).to_owned(),
338+
)))
331339
} else {
332340
Ok(None)
333341
}

0 commit comments

Comments
 (0)