Skip to content

Commit b27c5d9

Browse files
sebpuetz authored and Daniël de Kok committed
Add embedding similarity queries.
Add a method to perform similarity queries based on an input embedding rather than words.
1 parent 3ec049c commit b27c5d9

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/embeddings.rs

Lines changed: 106 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::cell::RefCell;
2+
use std::collections::HashSet;
23
use std::fs::File;
34
use std::io::{BufReader, BufWriter};
45
use std::rc::Rc;
@@ -12,11 +13,11 @@ use finalfusion::prelude::*;
1213
use finalfusion::similarity::*;
1314
use itertools::Itertools;
1415
use ndarray::Array2;
15-
use numpy::{IntoPyArray, PyArray1, PyArray2};
16+
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2};
1617
use pyo3::class::iter::PyIterProtocol;
1718
use pyo3::prelude::*;
18-
use pyo3::types::PyTuple;
19-
use pyo3::{exceptions, PyMappingProtocol};
19+
use pyo3::types::{PyAny, PyList, PySet, PyTuple};
20+
use pyo3::{exceptions, PyMappingProtocol, PyTypeInfo};
2021
use toml::{self, Value};
2122

2223
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
@@ -143,18 +144,7 @@ impl PyEmbeddings {
143144
exceptions::KeyError::py_err(format!("Unknown word or n-grams: {}", failed))
144145
})?;
145146

146-
let mut r = Vec::with_capacity(results.len());
147-
for ws in results {
148-
r.push(
149-
Py::new(
150-
py,
151-
PyWordSimilarity::new(ws.word.to_owned(), ws.similarity.into_inner()),
152-
)?
153-
.into_object(py),
154-
)
155-
}
156-
157-
Ok(r)
147+
Self::similarity_results(py, results)
158148
}
159149

160150
/// Get the embedding for the given word.
@@ -258,7 +248,7 @@ impl PyEmbeddings {
258248

259249
/// Perform a similarity query.
260250
#[args(limit = 10)]
261-
fn similarity(&self, py: Python, word: &str, limit: usize) -> PyResult<Vec<PyObject>> {
251+
fn word_similarity(&self, py: Python, word: &str, limit: usize) -> PyResult<Vec<PyObject>> {
262252
let embeddings = self.embeddings.borrow();
263253

264254
let embeddings = embeddings.view().ok_or_else(|| {
@@ -271,18 +261,46 @@ impl PyEmbeddings {
271261
.word_similarity(word, limit)
272262
.ok_or_else(|| exceptions::KeyError::py_err("Unknown word and n-grams"))?;
273263

274-
let mut r = Vec::with_capacity(results.len());
275-
for ws in results {
276-
r.push(
277-
Py::new(
278-
py,
279-
PyWordSimilarity::new(ws.word.to_owned(), ws.similarity.into_inner()),
280-
)?
281-
.into_object(py),
264+
Self::similarity_results(py, results)
265+
}
266+
267+
/// Perform a similarity query based on a query embedding.
268+
#[args(limit = 10, skip = "None")]
269+
fn embedding_similarity(
270+
&self,
271+
py: Python,
272+
embedding: PyEmbedding,
273+
skip: Option<Option<Skips>>,
274+
limit: usize,
275+
) -> PyResult<Vec<PyObject>> {
276+
let embeddings = self.embeddings.borrow();
277+
278+
let embeddings = embeddings.view().ok_or_else(|| {
279+
exceptions::ValueError::py_err(
280+
"Similarity queries are not supported for this type of embedding matrix",
282281
)
282+
})?;
283+
284+
let embedding = embedding.0.as_array();
285+
286+
if embedding.shape()[0] != embeddings.storage().shape().1 {
287+
return Err(exceptions::ValueError::py_err(format!(
288+
"Incompatible embedding shapes: embeddings: ({},), query: ({},)",
289+
embedding.shape()[0],
290+
embeddings.storage().shape().1
291+
)));
283292
}
284293

285-
Ok(r)
294+
let results = if let Some(Some(skip)) = skip {
295+
embeddings.embedding_similarity_masked(embedding, limit, &skip.0)
296+
} else {
297+
embeddings.embedding_similarity(embedding, limit)
298+
};
299+
300+
Self::similarity_results(
301+
py,
302+
results.ok_or_else(|| exceptions::KeyError::py_err("Unknown word and n-grams"))?,
303+
)
286304
}
287305

288306
/// Write the embeddings to a finalfusion file.
@@ -304,6 +322,25 @@ impl PyEmbeddings {
304322
}
305323
}
306324

325+
impl PyEmbeddings {
326+
fn similarity_results(
327+
py: Python,
328+
results: Vec<WordSimilarityResult>,
329+
) -> PyResult<Vec<PyObject>> {
330+
let mut r = Vec::with_capacity(results.len());
331+
for ws in results {
332+
r.push(
333+
Py::new(
334+
py,
335+
PyWordSimilarity::new(ws.word.to_owned(), ws.similarity.into_inner()),
336+
)?
337+
.into_object(py),
338+
)
339+
}
340+
Ok(r)
341+
}
342+
}
343+
307344
#[pyproto]
308345
impl PyMappingProtocol for PyEmbeddings {
309346
fn __getitem__(&self, word: &str) -> PyResult<Py<PyArray1<f32>>> {
@@ -372,3 +409,47 @@ where
372409
embeddings: Rc::new(RefCell::new(EmbeddingsWrap::View(embeddings.into()))),
373410
})
374411
}
412+
413+
struct Skips<'a>(HashSet<&'a str>);
414+
415+
impl<'a> FromPyObject<'a> for Skips<'a> {
416+
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
417+
let mut set = ob
418+
.len()
419+
.map(|len| HashSet::with_capacity(len))
420+
.unwrap_or_default();
421+
422+
let iter = if <PySet as PyTypeInfo>::is_instance(ob) {
423+
ob.iter().unwrap()
424+
} else if <PyList as PyTypeInfo>::is_instance(ob) {
425+
ob.iter().unwrap()
426+
} else {
427+
return Err(exceptions::TypeError::py_err("Iterable expected"));
428+
};
429+
430+
for el in iter {
431+
set.insert(
432+
el?.extract()
433+
.map_err(|_| exceptions::TypeError::py_err("Expected String"))?,
434+
);
435+
}
436+
Ok(Skips(set))
437+
}
438+
}
439+
440+
struct PyEmbedding<'a>(&'a PyArray1<f32>);
441+
442+
impl<'a> FromPyObject<'a> for PyEmbedding<'a> {
443+
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
444+
let embedding = ob
445+
.downcast_ref::<PyArray1<f32>>()
446+
.map_err(|_| exceptions::TypeError::py_err("Expected array with dtype Float32"))?;
447+
if embedding.data_type() != NpyDataType::Float32 {
448+
return Err(exceptions::TypeError::py_err(format!(
449+
"Expected dtype Float32, got {:?}",
450+
embedding.data_type()
451+
)));
452+
};
453+
Ok(PyEmbedding(embedding))
454+
}
455+
}

tests/test_similarity.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import pytest
2+
import numpy
3+
14
SIMILARITY_ORDER_STUTTGART_10 = [
25
"Karlsruhe",
36
"Mannheim",
@@ -57,10 +60,28 @@
5760

5861

5962
def test_similarity_berlin_40(similarity_fifu):
60-
for idx, sim in enumerate(similarity_fifu.similarity("Berlin", 40)):
63+
for idx, sim in enumerate(similarity_fifu.word_similarity("Berlin", 40)):
6164
assert SIMILARITY_ORDER[idx] == sim.word
6265

6366

6467
def test_similarity_stuttgart_10(similarity_fifu):
65-
for idx, sim in enumerate(similarity_fifu.similarity("Stuttgart", 10)):
68+
for idx, sim in enumerate(similarity_fifu.word_similarity("Stuttgart", 10)):
69+
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word
70+
71+
72+
def test_embedding_similarity_stuttgart_10(similarity_fifu):
    """Query with the embedding of "Stuttgart" and check the neighbour order.

    The query runs twice: once unmasked, where the query word itself is
    expected as the first hit, and once with "Stuttgart" in the skip set.
    """
    query = similarity_fifu.embedding("Stuttgart")

    unmasked = similarity_fifu.embedding_similarity(query, limit=10)
    assert unmasked[0].word == "Stuttgart"

    # The remaining hits must follow the reference order.
    for position, neighbour in enumerate(unmasked[1:]):
        assert SIMILARITY_ORDER_STUTTGART_10[position] == neighbour.word

    masked = similarity_fifu.embedding_similarity(
        query, skip={"Stuttgart"}, limit=10
    )
    for position, neighbour in enumerate(masked):
        assert SIMILARITY_ORDER_STUTTGART_10[position] == neighbour.word
82+
83+
84+
def test_embedding_similarity_incompatible_shapes(similarity_fifu):
    """A single-element float32 query must be rejected with ValueError."""
    single_element_query = numpy.ones(1, dtype=numpy.float32)

    with pytest.raises(ValueError):
        similarity_fifu.embedding_similarity(single_element_query)

0 commit comments

Comments
 (0)