Skip to content

Commit fa439bb

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Add PyMappingProtocol impl for PyEmbeddings.
1 parent 1d8eb8c commit fa439bb

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

src/embeddings.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use finalfusion::similarity::*;
1010
use ndarray::Array2;
1111
use numpy::{IntoPyArray, PyArray1, PyArray2};
1212
use pyo3::class::iter::PyIterProtocol;
13-
use pyo3::exceptions;
1413
use pyo3::prelude::*;
1514
use pyo3::types::PyTuple;
15+
use pyo3::{exceptions, PyMappingProtocol};
1616
use toml::{self, Value};
1717

1818
use crate::{
@@ -114,13 +114,7 @@ impl PyEmbeddings {
114114
fn embedding(&self, word: &str) -> Option<Py<PyArray1<f32>>> {
115115
let embeddings = self.embeddings.borrow();
116116

117-
use EmbeddingsWrap::*;
118-
let embedding = match &*embeddings {
119-
View(e) => e.embedding(word),
120-
NonView(e) => e.embedding(word),
121-
};
122-
123-
embedding.map(|e| {
117+
embeddings.embedding(word).map(|e| {
124118
let gil = pyo3::Python::acquire_gil();
125119
e.into_owned().into_pyarray(gil.python()).to_owned()
126120
})
@@ -269,6 +263,21 @@ impl PyEmbeddings {
269263
}
270264
}
271265

266+
#[pyproto]
267+
impl PyMappingProtocol for PyEmbeddings {
268+
fn __getitem__(&self, word: &str) -> PyResult<Py<PyArray1<f32>>> {
269+
let embeddings = self.embeddings.borrow();
270+
271+
match embeddings.embedding(word) {
272+
Some(embedding) => {
273+
let gil = pyo3::Python::acquire_gil();
274+
Ok(embedding.into_owned().into_pyarray(gil.python()).to_owned())
275+
}
276+
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
277+
}
278+
}
279+
}
280+
272281
#[pyproto]
273282
impl PyIterProtocol for PyEmbeddings {
274283
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyObject> {

src/embeddings_wrap.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use finalfusion::norms::NdNorms;
22
use finalfusion::prelude::*;
3+
use finalfusion::storage::CowArray1;
34

45
pub enum EmbeddingsWrap {
56
NonView(Embeddings<VocabWrap, StorageWrap>),
@@ -30,4 +31,12 @@ impl EmbeddingsWrap {
3031
View(e) => e.norms(),
3132
}
3233
}
34+
35+
pub fn embedding(&self, word: &str) -> Option<CowArray1<f32>> {
36+
use EmbeddingsWrap::*;
37+
match self {
38+
View(e) => e.embedding(word),
39+
NonView(e) => e.embedding(word),
40+
}
41+
}
3342
}

tests/test_embedding.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def test_embeddings_with_norms_oov(embeddings_fifu):
2828
"Something out of vocabulary") is None
2929

3030

31+
def test_indexing(embeddings_fifu):
32+
assert embeddings_fifu["one"] is not None
33+
with pytest.raises(KeyError):
34+
embeddings_fifu["Something out of vocabulary"]
35+
36+
3137
def test_embeddings(embeddings_fifu, embeddings_text):
3238
for embedding_with_norm, norm in zip(embeddings_fifu, TEST_NORMS):
3339
unnormed_embed = embedding_with_norm[1] * norm

0 commit comments

Comments
 (0)