Skip to content

Commit 4fefb6a

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Add Vocab to Python module.
1 parent 5cad188 commit 4fefb6a

File tree

1 file changed

+68
-1
lines changed

1 file changed

+68
-1
lines changed

src/lib.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ use failure::Error;
99
use finalfusion::metadata::Metadata;
1010
use finalfusion::prelude::*;
1111
use finalfusion::similarity::*;
12+
use finalfusion::vocab::WordIndex;
1213
use ndarray::Array2;
1314
use numpy::{IntoPyArray, PyArray1, PyArray2};
14-
use pyo3::class::{basic::PyObjectProtocol, iter::PyIterProtocol};
15+
use pyo3::class::{basic::PyObjectProtocol, iter::PyIterProtocol, sequence::PySequenceProtocol};
1516
use pyo3::exceptions;
1617
use pyo3::prelude::*;
1718
use toml::{self, Value};
@@ -24,6 +25,7 @@ use toml::{self, Value};
2425
fn finalfusion(_py: Python, m: &PyModule) -> PyResult<()> {
2526
m.add_class::<PyEmbeddings>()?;
2627
m.add_class::<PyWordSimilarity>()?;
28+
m.add_class::<PyVocab>()?;
2729
Ok(())
2830
}
2931

@@ -97,6 +99,13 @@ impl PyEmbeddings {
9799
Ok(())
98100
}
99101

102+
/// Get the model's vocabulary.
103+
fn vocab(&self) -> PyResult<PyVocab> {
104+
Ok(PyVocab {
105+
embeddings: self.embeddings.clone(),
106+
})
107+
}
108+
100109
/// Perform an anology query.
101110
///
102111
/// This returns words for the analogy query *w1* is to *w2*
@@ -309,6 +318,64 @@ impl PyIterProtocol for PyEmbeddings {
309318
}
310319
}
311320

321+
/// finalfusion vocab.
322+
#[pyclass(name=Vocab)]
323+
struct PyVocab {
324+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
325+
}
326+
327+
#[pymethods]
328+
impl PyVocab {
329+
fn item_to_indices(&self, key: String) -> PyResult<PyObject> {
330+
let embeds = self.embeddings.borrow();
331+
332+
use EmbeddingsWrap::*;
333+
let indices = match &*embeds {
334+
View(e) => e.vocab().idx(key.as_str()),
335+
NonView(e) => e.vocab().idx(key.as_str()),
336+
};
337+
indices
338+
.map(|idx| {
339+
let gil = pyo3::Python::acquire_gil();
340+
match idx {
341+
WordIndex::Word(idx) => [idx].to_object(gil.python()),
342+
WordIndex::Subword(indices) => indices.to_object(gil.python()),
343+
}
344+
})
345+
.ok_or_else(|| exceptions::KeyError::py_err("Unknown word or n-grams"))
346+
}
347+
}
348+
349+
#[pyproto]
350+
impl PySequenceProtocol for PyVocab {
351+
fn __len__(&self) -> PyResult<usize> {
352+
let embeds = self.embeddings.borrow();
353+
354+
use EmbeddingsWrap::*;
355+
match &*embeds {
356+
View(e) => Ok(e.vocab().words().len()),
357+
NonView(e) => Ok(e.vocab().words().len()),
358+
}
359+
}
360+
361+
fn __getitem__(&self, idx: isize) -> PyResult<String> {
362+
let embeds = self.embeddings.borrow();
363+
364+
use EmbeddingsWrap::*;
365+
366+
let v = match &*embeds {
367+
View(e) => e.vocab(),
368+
NonView(e) => e.vocab(),
369+
};
370+
371+
if idx >= v.words().len() as isize || idx < 0 {
372+
Err(exceptions::IndexError::py_err("list index out of range"))
373+
} else {
374+
Ok(v.words()[idx as usize].clone())
375+
}
376+
}
377+
}
378+
312379
fn load_embeddings<S>(path: &str, mmap: bool) -> Result<Embeddings<VocabWrap, S>, Error>
313380
where
314381
Embeddings<VocabWrap, S>: ReadEmbeddings + MmapEmbeddings,

0 commit comments

Comments
 (0)