Skip to content

Commit 696007f

Browse files
danieldk authored and sebpuetz committed
Remove the specialized embedding iterator with norms
Instead, let the normal embeddings iterator return norms. Getting the norms is a simple, predictable memory traversal, and thus very cheap.
1 parent 21e071c commit 696007f

File tree

4 files changed

+14
-93
lines changed

4 files changed

+14
-93
lines changed

src/embeddings.rs

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,7 @@ use pyo3::types::PyTuple;
1515
use pyo3::{exceptions, PyMappingProtocol};
1616
use toml::{self, Value};
1717

18-
use crate::{
19-
EmbeddingsWrap, PyEmbeddingIterator, PyEmbeddingWithNormIterator, PyVocab, PyWordSimilarity,
20-
};
18+
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
2119

2220
/// finalfusion embeddings.
2321
#[pyclass(name=Embeddings)]
@@ -257,10 +255,6 @@ impl PyEmbeddings {
257255
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
258256
}
259257
}
260-
261-
fn iter_with_norm(&self) -> PyResult<PyEmbeddingWithNormIterator> {
262-
Ok(PyEmbeddingWithNormIterator::new(self.embeddings.clone(), 0))
263-
}
264258
}
265259

266260
#[pyproto]

src/iter.rs

Lines changed: 4 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -32,70 +32,6 @@ impl PyIterProtocol for PyEmbeddingIterator {
3232
let embeddings = slf.embeddings.borrow();
3333
let vocab = embeddings.vocab();
3434

35-
if slf.idx < vocab.len() {
36-
let word = vocab.words()[slf.idx].to_string();
37-
let embed = embeddings.storage().embedding(slf.idx);
38-
39-
slf.idx += 1;
40-
41-
let gil = pyo3::Python::acquire_gil();
42-
Ok(Some(PyEmbedding {
43-
word,
44-
embedding: embed.into_owned().into_pyarray(gil.python()).to_owned(),
45-
}))
46-
} else {
47-
Ok(None)
48-
}
49-
}
50-
}
51-
52-
/// A word and its embedding.
53-
#[pyclass(name=Embedding)]
54-
pub struct PyEmbedding {
55-
embedding: Py<PyArray1<f32>>,
56-
word: String,
57-
}
58-
59-
#[pymethods]
60-
impl PyEmbedding {
61-
/// Get the embedding.
62-
#[getter]
63-
pub fn get_embedding(&self) -> Py<PyArray1<f32>> {
64-
let gil = Python::acquire_gil();
65-
self.embedding.clone_ref(gil.python())
66-
}
67-
68-
/// Get the word.
69-
#[getter]
70-
pub fn get_word(&self) -> &str {
71-
&self.word
72-
}
73-
}
74-
75-
#[pyclass(name=EmbeddingWithNormIterator)]
76-
pub struct PyEmbeddingWithNormIterator {
77-
embeddings: Rc<RefCell<EmbeddingsWrap>>,
78-
idx: usize,
79-
}
80-
81-
impl PyEmbeddingWithNormIterator {
82-
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
83-
PyEmbeddingWithNormIterator { embeddings, idx }
84-
}
85-
}
86-
87-
#[pyproto]
88-
impl PyIterProtocol for PyEmbeddingWithNormIterator {
89-
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyEmbeddingWithNormIterator>> {
90-
Ok(slf.into())
91-
}
92-
93-
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyEmbeddingWithNorm>> {
94-
let slf = &mut *slf;
95-
96-
let embeddings = slf.embeddings.borrow();
97-
let vocab = embeddings.vocab();
98-
9935
if slf.idx < vocab.len() {
10036
let word = vocab.words()[slf.idx].to_string();
10137
let embed = embeddings.storage().embedding(slf.idx);
@@ -104,7 +40,7 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
10440
slf.idx += 1;
10541

10642
let gil = pyo3::Python::acquire_gil();
107-
Ok(Some(PyEmbeddingWithNorm {
43+
Ok(Some(PyEmbedding {
10844
word,
10945
embedding: embed.into_owned().into_pyarray(gil.python()).to_owned(),
11046
norm,
@@ -116,15 +52,15 @@ impl PyIterProtocol for PyEmbeddingWithNormIterator {
11652
}
11753

11854
/// A word and its embedding and embedding norm.
119-
#[pyclass(name=EmbeddingWithNorm)]
120-
pub struct PyEmbeddingWithNorm {
55+
#[pyclass(name=Embedding)]
56+
pub struct PyEmbedding {
12157
embedding: Py<PyArray1<f32>>,
12258
norm: f32,
12359
word: String,
12460
}
12561

12662
#[pymethods]
127-
impl PyEmbeddingWithNorm {
63+
impl PyEmbedding {
12864
/// Get the embedding.
12965
#[getter]
13066
pub fn get_embedding(&self) -> Py<PyArray1<f32>> {

src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ mod embeddings_wrap;
99
use embeddings_wrap::EmbeddingsWrap;
1010

1111
mod iter;
12-
use iter::{PyEmbedding, PyEmbeddingIterator, PyEmbeddingWithNorm, PyEmbeddingWithNormIterator};
12+
use iter::{PyEmbedding, PyEmbeddingIterator};
1313

1414
mod similarity;
1515
use similarity::PyWordSimilarity;
@@ -25,7 +25,6 @@ use vocab::PyVocab;
2525
fn finalfusion(_py: Python, m: &PyModule) -> PyResult<()> {
2626
m.add_class::<PyEmbeddings>()?;
2727
m.add_class::<PyEmbedding>()?;
28-
m.add_class::<PyEmbeddingWithNorm>()?;
2928
m.add_class::<PyWordSimilarity>()?;
3029
m.add_class::<PyVocab>()?;
3130
Ok(())

tests/test_embedding.py

Lines changed: 8 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,11 @@
1212
]
1313

1414

15-
def test_embeddings_with_norms(embeddings_fifu, embeddings_text):
16-
for embedding_with_norm, norm in zip(
17-
embeddings_fifu.iter_with_norm(), TEST_NORMS):
18-
unnormed_embed = embedding_with_norm.embedding * norm
19-
test_embed = embeddings_text[embedding_with_norm.word]
15+
def test_embeddings(embeddings_fifu, embeddings_text):
16+
for embedding, norm in zip(
17+
embeddings_fifu, TEST_NORMS):
18+
unnormed_embed = embedding.embedding * embedding.norm
19+
test_embed = embeddings_text[embedding.word]
2020
assert numpy.allclose(
2121
unnormed_embed, test_embed), "Embedding from 'iter_with_norm()' fails to match!"
2222

@@ -32,20 +32,12 @@ def test_indexing(embeddings_fifu):
3232
embeddings_fifu["Something out of vocabulary"]
3333

3434

35-
def test_embeddings(embeddings_fifu, embeddings_text):
36-
for embedding, norm in zip(embeddings_fifu, TEST_NORMS):
37-
unnormed_embed = embedding.embedding * norm
38-
test_embed = embeddings_text[embedding.word]
39-
assert numpy.allclose(
40-
unnormed_embed, test_embed), "Embedding from normal iterator fails to match!"
41-
42-
4335
def test_embeddings_oov(embeddings_fifu):
4436
assert embeddings_fifu.embedding("Something out of vocabulary") is None
4537

4638

4739
def test_norms(embeddings_fifu):
48-
for embedding_with_norm, norm in zip(
49-
embeddings_fifu.iter_with_norm(), TEST_NORMS):
40+
for embedding, norm in zip(
41+
embeddings_fifu, TEST_NORMS):
5042
assert pytest.approx(
51-
embedding_with_norm.norm) == norm, "Norm fails to match!"
43+
embedding.norm) == norm, "Norm fails to match!"

0 commit comments

Comments
 (0)