Skip to content

Commit 21e071c

Browse files
danieldk (Daniël de Kok)
authored and committed
Stop returning tuples in iteration over embeddings
When iterating over embeddings with norms, we get 3-tuples. This means that the user has to memorize which tuple element is what. So, instead produce instances of an EmbeddingWithNorm class, which has the properties word/embedding/norm. For consistency, update iteration without norms as well.
1 parent 9afa6ac commit 21e071c

File tree

3 files changed

+70
-19
lines changed

3 files changed

+70
-19
lines changed

src/iter.rs

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl PyIterProtocol for PyEmbeddingIterator {
2626
Ok(slf.into())
2727
}
2828

29-
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Py<PyArray1<f32>>)>> {
29+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyEmbedding>> {
3030
let slf = &mut *slf;
3131

3232
let embeddings = slf.embeddings.borrow();
@@ -39,16 +39,39 @@ impl PyIterProtocol for PyEmbeddingIterator {
3939
slf.idx += 1;
4040

4141
let gil = pyo3::Python::acquire_gil();
42-
Ok(Some((
42+
Ok(Some(PyEmbedding {
4343
word,
44-
embed.into_owned().into_pyarray(gil.python()).to_owned(),
45-
)))
44+
embedding: embed.into_owned().into_pyarray(gil.python()).to_owned(),
45+
}))
4646
} else {
4747
Ok(None)
4848
}
4949
}
5050
}
5151

52+
/// A word and its embedding.
53+
#[pyclass(name=Embedding)]
54+
pub struct PyEmbedding {
55+
embedding: Py<PyArray1<f32>>,
56+
word: String,
57+
}
58+
59+
#[pymethods]
60+
impl PyEmbedding {
61+
/// Get the embedding.
62+
#[getter]
63+
pub fn get_embedding(&self) -> Py<PyArray1<f32>> {
64+
let gil = Python::acquire_gil();
65+
self.embedding.clone_ref(gil.python())
66+
}
67+
68+
/// Get the word.
69+
#[getter]
70+
pub fn get_word(&self) -> &str {
71+
&self.word
72+
}
73+
}
74+
5275
#[pyclass(name=EmbeddingWithNormIterator)]
5376
pub struct PyEmbeddingWithNormIterator {
5477
embeddings: Rc<RefCell<EmbeddingsWrap>>,
@@ -67,7 +90,7 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
6790
Ok(slf.into())
6891
}
6992

70-
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Py<PyArray1<f32>>, f32)>> {
93+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyEmbeddingWithNorm>> {
7194
let slf = &mut *slf;
7295

7396
let embeddings = slf.embeddings.borrow();
@@ -81,13 +104,43 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
81104
slf.idx += 1;
82105

83106
let gil = pyo3::Python::acquire_gil();
84-
Ok(Some((
107+
Ok(Some(PyEmbeddingWithNorm {
85108
word,
86-
embed.into_owned().into_pyarray(gil.python()).to_owned(),
109+
embedding: embed.into_owned().into_pyarray(gil.python()).to_owned(),
87110
norm,
88-
)))
111+
}))
89112
} else {
90113
Ok(None)
91114
}
92115
}
93116
}
117+
118+
/// A word and its embedding and embedding norm.
119+
#[pyclass(name=EmbeddingWithNorm)]
120+
pub struct PyEmbeddingWithNorm {
121+
embedding: Py<PyArray1<f32>>,
122+
norm: f32,
123+
word: String,
124+
}
125+
126+
#[pymethods]
127+
impl PyEmbeddingWithNorm {
128+
/// Get the embedding.
129+
#[getter]
130+
pub fn get_embedding(&self) -> Py<PyArray1<f32>> {
131+
let gil = Python::acquire_gil();
132+
self.embedding.clone_ref(gil.python())
133+
}
134+
135+
/// Get the word.
136+
#[getter]
137+
pub fn get_word(&self) -> &str {
138+
&self.word
139+
}
140+
141+
/// Get the norm.
142+
#[getter]
143+
pub fn get_norm(&self) -> f32 {
144+
self.norm
145+
}
146+
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod embeddings_wrap;
99
use embeddings_wrap::EmbeddingsWrap;
1010

1111
mod iter;
12-
use iter::{PyEmbeddingIterator, PyEmbeddingWithNormIterator};
12+
use iter::{PyEmbedding, PyEmbeddingIterator, PyEmbeddingWithNorm, PyEmbeddingWithNormIterator};
1313

1414
mod similarity;
1515
use similarity::PyWordSimilarity;
@@ -24,6 +24,8 @@ use vocab::PyVocab;
2424
#[pymodule]
2525
fn finalfusion(_py: Python, m: &PyModule) -> PyResult<()> {
2626
m.add_class::<PyEmbeddings>()?;
27+
m.add_class::<PyEmbedding>()?;
28+
m.add_class::<PyEmbeddingWithNorm>()?;
2729
m.add_class::<PyWordSimilarity>()?;
2830
m.add_class::<PyVocab>()?;
2931
Ok(())

tests/test_embedding.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
def test_embeddings_with_norms(embeddings_fifu, embeddings_text):
1616
for embedding_with_norm, norm in zip(
1717
embeddings_fifu.iter_with_norm(), TEST_NORMS):
18-
unnormed_embed = embedding_with_norm[1] * norm
19-
test_embed = embeddings_text[embedding_with_norm[0]]
18+
unnormed_embed = embedding_with_norm.embedding * norm
19+
test_embed = embeddings_text[embedding_with_norm.word]
2020
assert numpy.allclose(
2121
unnormed_embed, test_embed), "Embedding from 'iter_with_norm()' fails to match!"
22-
assert len(
23-
embedding_with_norm) == 3, "The number of values returned by 'iter_with_norm()' does not match!"
2422

2523

2624
def test_embeddings_with_norms_oov(embeddings_fifu):
@@ -35,13 +33,11 @@ def test_indexing(embeddings_fifu):
3533

3634

3735
def test_embeddings(embeddings_fifu, embeddings_text):
38-
for embedding_with_norm, norm in zip(embeddings_fifu, TEST_NORMS):
39-
unnormed_embed = embedding_with_norm[1] * norm
40-
test_embed = embeddings_text[embedding_with_norm[0]]
36+
for embedding, norm in zip(embeddings_fifu, TEST_NORMS):
37+
unnormed_embed = embedding.embedding * norm
38+
test_embed = embeddings_text[embedding.word]
4139
assert numpy.allclose(
4240
unnormed_embed, test_embed), "Embedding from normal iterator fails to match!"
43-
assert len(
44-
embedding_with_norm) == 2, "The number of values returned by normal iterator does not match!"
4541

4642

4743
def test_embeddings_oov(embeddings_fifu):
@@ -52,4 +48,4 @@ def test_norms(embeddings_fifu):
5248
for embedding_with_norm, norm in zip(
5349
embeddings_fifu.iter_with_norm(), TEST_NORMS):
5450
assert pytest.approx(
55-
embedding_with_norm[2]) == norm, "Norm fails to match!"
51+
embedding_with_norm.norm) == norm, "Norm fails to match!"

0 commit comments

Comments
 (0)