Skip to content

Commit e16c867

Browse files
committed
Add default value to embedding lookup.
Allow specification of a default return value for embedding lookups. Iterable, Array, Scalar and None are permitted values.
1 parent 63ded3d commit e16c867

File tree

2 files changed

+120
-9
lines changed

2 files changed

+120
-9
lines changed

src/embeddings.rs

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,19 @@ use finalfusion::io as ffio;
1111
use finalfusion::prelude::*;
1212
use finalfusion::similarity::*;
1313
use itertools::Itertools;
14+
use ndarray::Array1;
1415
use numpy::{IntoPyArray, NpyDataType, PyArray1};
1516
use pyo3::class::iter::PyIterProtocol;
1617
use pyo3::prelude::*;
17-
use pyo3::types::{PyAny, PyTuple};
18+
use pyo3::types::{PyAny, PyIterator, PyTuple};
1819
use pyo3::{exceptions, PyMappingProtocol};
1920
use toml::{self, Value};
2021

2122
use crate::storage::PyStorage;
2223
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
2324

2425
/// finalfusion embeddings.
25-
#[pyclass(name=Embeddings)]
26+
#[pyclass(name = Embeddings)]
2627
pub struct PyEmbeddings {
2728
// The use of Rc + RefCell should be safe in this crate:
2829
//
@@ -179,17 +180,49 @@ impl PyEmbeddings {
179180
Self::similarity_results(py, results)
180181
}
181182

183+
/// embedding(word,/, default)
184+
/// --
185+
///
182186
/// Get the embedding for the given word.
183187
///
184188
/// If the word is not known, its representation is approximated
185-
/// using subword units.
186-
fn embedding(&self, word: &str) -> Option<Py<PyArray1<f32>>> {
189+
/// using subword units. #
190+
///
191+
/// If no representation can be calculated:
192+
/// - `None` if `default` is `None`
193+
/// - an array filled with `default` if `default` is a scalar
194+
/// - an array if `default` is a 1-d array
195+
/// - an array filled with values from `default` if it is an iterator over floats.
196+
#[args(default = "PyEmbeddingDefault::default()")]
197+
fn embedding(
198+
&self,
199+
word: &str,
200+
default: PyEmbeddingDefault,
201+
) -> PyResult<Option<Py<PyArray1<f32>>>> {
187202
let embeddings = self.embeddings.borrow();
203+
let gil = pyo3::Python::acquire_gil();
204+
if let PyEmbeddingDefault::Embedding(array) = &default {
205+
if array.as_ref(gil.python()).shape()[0] != embeddings.storage().shape().1 {
206+
return Err(exceptions::ValueError::py_err(format!(
207+
"Invalid shape of default embedding: {}",
208+
array.as_ref(gil.python()).shape()[0]
209+
)));
210+
}
211+
}
188212

189-
embeddings.embedding(word).map(|e| {
190-
let gil = pyo3::Python::acquire_gil();
191-
e.into_owned().into_pyarray(gil.python()).to_owned()
192-
})
213+
if let Some(embedding) = embeddings.embedding(word) {
214+
return Ok(Some(
215+
embedding.into_owned().into_pyarray(gil.python()).to_owned(),
216+
));
217+
};
218+
match default {
219+
PyEmbeddingDefault::Constant(constant) => {
220+
let nd_arr = Array1::from_elem([embeddings.storage().shape().1], constant);
221+
Ok(Some(nd_arr.into_pyarray(gil.python()).to_owned()))
222+
}
223+
PyEmbeddingDefault::Embedding(array) => Ok(Some(array)),
224+
PyEmbeddingDefault::None => Ok(None),
225+
}
193226
}
194227

195228
fn embedding_with_norm(&self, word: &str) -> Option<Py<PyTuple>> {
@@ -415,6 +448,58 @@ where
415448
})
416449
}
417450

451+
pub enum PyEmbeddingDefault {
452+
Embedding(Py<PyArray1<f32>>),
453+
Constant(f32),
454+
None,
455+
}
456+
457+
impl<'a> Default for PyEmbeddingDefault {
458+
fn default() -> Self {
459+
PyEmbeddingDefault::None
460+
}
461+
}
462+
463+
impl<'a> FromPyObject<'a> for PyEmbeddingDefault {
464+
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
465+
if ob.is_none() {
466+
return Ok(PyEmbeddingDefault::None);
467+
}
468+
if let Ok(emb) = ob
469+
.extract()
470+
.map(|e: &PyArray1<f32>| PyEmbeddingDefault::Embedding(e.to_owned()))
471+
{
472+
return Ok(emb);
473+
}
474+
475+
if let Ok(constant) = ob.extract().map(PyEmbeddingDefault::Constant) {
476+
return Ok(constant);
477+
}
478+
if let Ok(embed) = ob
479+
.iter()
480+
.and_then(|iter| collect_array_from_py_iter(iter, ob.len().ok()))
481+
.map(PyEmbeddingDefault::Embedding)
482+
{
483+
return Ok(embed);
484+
}
485+
486+
Err(exceptions::TypeError::py_err(
487+
"failed to construct default value.",
488+
))
489+
}
490+
}
491+
492+
fn collect_array_from_py_iter(iter: PyIterator, len: Option<usize>) -> PyResult<Py<PyArray1<f32>>> {
493+
let mut embed_vec = len.map(Vec::with_capacity).unwrap_or_default();
494+
for item in iter {
495+
let item = item.and_then(|item| item.extract())?;
496+
embed_vec.push(item);
497+
}
498+
let gil = Python::acquire_gil();
499+
let embed = PyArray1::from_vec(gil.python(), embed_vec).to_owned();
500+
Ok(embed)
501+
}
502+
418503
struct Skips<'a>(HashSet<&'a str>);
419504

420505
impl<'a> FromPyObject<'a> for Skips<'a> {

tests/test_embeddings.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_embeddings(embeddings_fifu, embeddings_text, embeddings_text_dims):
2222
# The correct dimensionality of the other embedding types is asserted
2323
# in the pairwise comparisons below.
2424
assert fifu_storage.shape() == (7, 10)
25-
25+
2626
for embedding, storage_row in zip(embeddings_fifu, fifu_storage):
2727
assert numpy.allclose(
2828
embedding.embedding, embeddings_text[embedding.word]), "FiFu and text embedding mismatch"
@@ -32,6 +32,32 @@ def test_embeddings(embeddings_fifu, embeddings_text, embeddings_text_dims):
3232
embedding.embedding, storage_row), "FiFu and storage row mismatch"
3333

3434

35+
def test_unknown_embeddings(embeddings_fifu):
36+
assert embeddings_fifu.embedding("OOV") is None, "Unknown lookup with no default failed"
37+
assert embeddings_fifu.embedding(
38+
"OOV", default=None) is None, "Unknown lookup with 'None' default failed"
39+
assert numpy.allclose(embeddings_fifu.embedding(
40+
"OOV", default=[10]*10), numpy.array([10.]*10)), "Unknown lookup with 'list' default failed"
41+
assert numpy.allclose(embeddings_fifu.embedding("OOV", default=numpy.array(
42+
[10.]*10)), numpy.array([10.]*10)), "Unknown lookup with array default failed"
43+
assert numpy.allclose(embeddings_fifu.embedding(
44+
"OOV", default=10), numpy.array([10.]*10)), "Unknown lookup with 'int' scalar default failed"
45+
assert numpy.allclose(embeddings_fifu.embedding(
46+
"OOV", default=10.), numpy.array([10.]*10)), "Unknown lookup with 'float' scalar default failed"
47+
with pytest.raises(TypeError):
48+
embeddings_fifu.embedding(
49+
"OOV", default="not working"), "Unknown lookup with 'str' default succeeded"
50+
with pytest.raises(ValueError):
51+
embeddings_fifu.embedding(
52+
"OOV", default=[10.]*5), "Unknown lookup with incorrectly shaped 'list' default succeeded"
53+
with pytest.raises(ValueError):
54+
embeddings_fifu.embedding(
55+
"OOV", default=numpy.array([10.]*5)), "Unknown lookup with incorrectly shaped array default succeeded"
56+
with pytest.raises(ValueError):
57+
embeddings_fifu.embedding(
58+
"OOV", default=range(7)), "Unknown lookup with iterable default with incorrect number succeeded"
59+
60+
3561
def test_embeddings_pq(similarity_fifu, similarity_pq):
3662
for embedding in similarity_fifu:
3763
embedding_pq = similarity_pq.embedding("Berlin")

0 commit comments

Comments
 (0)