Skip to content

Commit 3869c64

Browse files
sebpuetz authored and
Daniël de Kok committed
Return None instead of raising exceptions for missing items.
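Change `embedding`, `embedding_with_norm`, and `item_to_indices` to return None rather than raising a `KeyError` when an item is missing. `matrix_copy` no longer wraps its return value in `PyResult`.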
1 parent 5562836 commit 3869c64

File tree

4 files changed

+42
-35
lines changed

4 files changed

+42
-35
lines changed

src/embeddings.rs

Lines changed: 13 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ impl PyEmbeddings {
111111
///
112112
/// If the word is not known, its representation is approximated
113113
/// using subword units.
114-
fn embedding(&self, word: &str) -> PyResult<Py<PyArray1<f32>>> {
114+
fn embedding(&self, word: &str) -> Option<Py<PyArray1<f32>>> {
115115
let embeddings = self.embeddings.borrow();
116116

117117
use EmbeddingsWrap::*;
@@ -120,16 +120,13 @@ impl PyEmbeddings {
120120
NonView(e) => e.embedding(word),
121121
};
122122

123-
match embedding {
124-
Some(embedding) => {
125-
let gil = pyo3::Python::acquire_gil();
126-
Ok(embedding.into_owned().into_pyarray(gil.python()).to_owned())
127-
}
128-
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
129-
}
123+
embedding.map(|e| {
124+
let gil = pyo3::Python::acquire_gil();
125+
e.into_owned().into_pyarray(gil.python()).to_owned()
126+
})
130127
}
131128

132-
fn embedding_with_norm(&self, word: &str) -> PyResult<Py<PyTuple>> {
129+
fn embedding_with_norm(&self, word: &str) -> Option<Py<PyTuple>> {
133130
let embeddings = self.embeddings.borrow();
134131

135132
use EmbeddingsWrap::*;
@@ -138,22 +135,15 @@ impl PyEmbeddings {
138135
NonView(e) => e.embedding_with_norm(word),
139136
};
140137

141-
match embedding_with_norm {
142-
Some(embedding_with_norm) => {
143-
let gil = pyo3::Python::acquire_gil();
144-
let py = gil.python();
145-
Ok((
146-
embedding_with_norm.embedding.into_owned().into_pyarray(py),
147-
embedding_with_norm.norm,
148-
)
149-
.into_py(py))
150-
}
151-
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
152-
}
138+
embedding_with_norm.map(|e| {
139+
let gil = pyo3::Python::acquire_gil();
140+
let embedding = e.embedding.into_owned().into_pyarray(gil.python());
141+
(embedding, e.norm).into_py(gil.python())
142+
})
153143
}
154144

155145
/// Copy the entire embeddings matrix.
156-
fn matrix_copy(&self) -> PyResult<Py<PyArray2<f32>>> {
146+
fn matrix_copy(&self) -> Py<PyArray2<f32>> {
157147
let embeddings = self.embeddings.borrow();
158148

159149
use EmbeddingsWrap::*;
@@ -175,7 +165,7 @@ impl PyEmbeddings {
175165
},
176166
};
177167
let gil = pyo3::Python::acquire_gil();
178-
Ok(matrix.into_pyarray(gil.python()).to_owned())
168+
matrix.into_pyarray(gil.python()).to_owned()
179169
}
180170

181171
/// Embeddings metadata.

src/vocab.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,16 @@ impl PyVocab {
2323

2424
#[pymethods]
2525
impl PyVocab {
26-
fn item_to_indices(&self, key: String) -> PyResult<PyObject> {
26+
fn item_to_indices(&self, key: String) -> Option<PyObject> {
2727
let embeds = self.embeddings.borrow();
2828

29-
embeds
30-
.vocab()
31-
.idx(key.as_str())
32-
.map(|idx| {
33-
let gil = pyo3::Python::acquire_gil();
34-
match idx {
35-
WordIndex::Word(idx) => [idx].to_object(gil.python()),
36-
WordIndex::Subword(indices) => indices.to_object(gil.python()),
37-
}
38-
})
39-
.ok_or_else(|| exceptions::KeyError::py_err("Unknown word or n-grams"))
29+
embeds.vocab().idx(key.as_str()).map(|idx| {
30+
let gil = pyo3::Python::acquire_gil();
31+
match idx {
32+
WordIndex::Word(idx) => [idx].to_object(gil.python()),
33+
WordIndex::Subword(indices) => indices.to_object(gil.python()),
34+
}
35+
})
4036
}
4137
}
4238

tests/test_embedding.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def test_embeddings_with_norms():
3131
embedding_with_norm) == 3, "The number of values returned by 'iter_with_norm()' does not match!"
3232

3333

34+
def test_embeddings_with_norms_oov():
35+
embeds = finalfusion.Embeddings(
36+
"tests/embeddings.fifu")
37+
assert embeds.embedding_with_norm("Something out of vocabulary") is None
38+
39+
3440
def test_embeddings():
3541
embeds = finalfusion.Embeddings(
3642
"tests/embeddings.fifu")
@@ -49,6 +55,12 @@ def test_embeddings():
4955
embedding_with_norm) == 2, "The number of values returned by normal iterator does not match!"
5056

5157

58+
def test_embeddings_oov():
59+
embeds = finalfusion.Embeddings(
60+
"tests/embeddings.fifu")
61+
assert embeds.embedding("Something out of vocabulary") is None
62+
63+
5264
def test_norms():
5365
embeds = finalfusion.Embeddings(
5466
"tests/embeddings.fifu")

tests/test_vocab.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import finalfusion
2+
import pytest
3+
4+
5+
def test_embeddings_with_norms_oov():
6+
embeds = finalfusion.Embeddings(
7+
"tests/embeddings.fifu")
8+
vocab = embeds.vocab()
9+
assert vocab.item_to_indices("Something out of vocabulary") is None

0 commit comments

Comments
 (0)