Skip to content

Commit 7a04721

Browse files
NianhengWu authored and Daniël de Kok committed
Expose norms functionality, add pytest unit, fix travis ci
1 parent e0d43b1 commit 7a04721

File tree

8 files changed

+156
-4
lines changed

8 files changed

+156
-4
lines changed

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ install:
2828
if [ "$TRAVIS_OS_NAME" == "osx" ]; then
2929
python3 -m venv venv
3030
source venv/bin/activate
31-
pip install cffi virtualenv pytest
31+
pip install cffi virtualenv pytest numpy
3232
fi
3333
- |
3434
if [ "$TRAVIS_OS_NAME" == "linux" ]; then
3535
python3.6 -m venv venv
3636
source venv/bin/activate
37-
pip install cffi virtualenv pytest
37+
pip install cffi virtualenv pytest numpy
3838
fi
3939
- cargo install pyo3-pack --vers 0.6.1
4040
- rustup default nightly-2019-02-07

src/embeddings.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ use numpy::{IntoPyArray, PyArray1, PyArray2};
1212
use pyo3::class::iter::PyIterProtocol;
1313
use pyo3::exceptions;
1414
use pyo3::prelude::*;
15+
use pyo3::types::PyTuple;
1516
use toml::{self, Value};
1617

17-
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
18+
use crate::{
19+
EmbeddingsWrap, PyEmbeddingIterator, PyEmbeddingWithNormIterator, PyVocab, PyWordSimilarity,
20+
};
1821

1922
/// finalfusion embeddings.
2023
#[pyclass(name=Embeddings)]
@@ -126,6 +129,29 @@ impl PyEmbeddings {
126129
}
127130
}
128131

132+
fn embedding_with_norm(&self, word: &str) -> PyResult<Py<PyTuple>> {
133+
let embeddings = self.embeddings.borrow();
134+
135+
use EmbeddingsWrap::*;
136+
let embedding_with_norm = match &*embeddings {
137+
View(e) => e.embedding_with_norm(word),
138+
NonView(e) => e.embedding_with_norm(word),
139+
};
140+
141+
match embedding_with_norm {
142+
Some(embedding_with_norm) => {
143+
let gil = pyo3::Python::acquire_gil();
144+
let py = gil.python();
145+
Ok((
146+
embedding_with_norm.embedding.into_owned().into_pyarray(py),
147+
embedding_with_norm.norm,
148+
)
149+
.into_py(py))
150+
}
151+
None => Err(exceptions::KeyError::py_err("Unknown word and n-grams")),
152+
}
153+
}
154+
129155
/// Copy the entire embeddings matrix.
130156
fn matrix_copy(&self) -> PyResult<Py<PyArray2<f32>>> {
131157
let embeddings = self.embeddings.borrow();
@@ -247,6 +273,10 @@ impl PyEmbeddings {
247273
.map_err(|err| exceptions::IOError::py_err(err.to_string())),
248274
}
249275
}
276+
277+
fn iter_with_norm(&self) -> PyResult<PyEmbeddingWithNormIterator> {
278+
Ok(PyEmbeddingWithNormIterator::new(self.embeddings.clone(), 0))
279+
}
250280
}
251281

252282
#[pyproto]

src/embeddings_wrap.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use finalfusion::norms::NdNorms;
12
use finalfusion::prelude::*;
23

34
pub enum EmbeddingsWrap {
@@ -21,4 +22,12 @@ impl EmbeddingsWrap {
2122
View(e) => e.vocab(),
2223
}
2324
}
25+
26+
pub fn norms(&self) -> Option<&NdNorms> {
27+
use EmbeddingsWrap::*;
28+
match self {
29+
NonView(e) => e.norms(),
30+
View(e) => e.norms(),
31+
}
32+
}
2433
}

src/iter.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,46 @@ impl PyIterProtocol for PyEmbeddingIterator {
4848
}
4949
}
5050
}
51+
52+
#[pyclass(name=EmbeddingWithNormIterator)]
53+
pub struct PyEmbeddingWithNormIterator {
54+
embeddings: Rc<RefCell<EmbeddingsWrap>>,
55+
idx: usize,
56+
}
57+
58+
impl PyEmbeddingWithNormIterator {
59+
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
60+
PyEmbeddingWithNormIterator { embeddings, idx }
61+
}
62+
}
63+
64+
#[pyproto]
65+
impl PyIterProtocol for PyEmbeddingWithNormIterator {
66+
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyEmbeddingWithNormIterator>> {
67+
Ok(slf.into())
68+
}
69+
70+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<(String, Py<PyArray1<f32>>, f32)>> {
71+
let slf = &mut *slf;
72+
73+
let embeddings = slf.embeddings.borrow();
74+
let vocab = embeddings.vocab();
75+
76+
if slf.idx < vocab.len() {
77+
let word = vocab.words()[slf.idx].to_string();
78+
let embed = embeddings.storage().embedding(slf.idx);
79+
let norm = embeddings.norms().map(|n| n.0[slf.idx]).unwrap_or(1.);
80+
81+
slf.idx += 1;
82+
83+
let gil = pyo3::Python::acquire_gil();
84+
Ok(Some((
85+
word,
86+
embed.into_owned().into_pyarray(gil.python()).to_owned(),
87+
norm,
88+
)))
89+
} else {
90+
Ok(None)
91+
}
92+
}
93+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod embeddings_wrap;
99
use embeddings_wrap::EmbeddingsWrap;
1010

1111
mod iter;
12-
use iter::PyEmbeddingIterator;
12+
use iter::{PyEmbeddingIterator, PyEmbeddingWithNormIterator};
1313

1414
mod similarity;
1515
use similarity::PyWordSimilarity;

tests/embeddings.fifu

460 Bytes
Binary file not shown.

tests/embeddings.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
one 3.0 1.0 0.0 0.0 0.0 0.0 2.0 2.0 4.0 3.0
2+
two 2.0 3.0 3.0 3.0 3.0 2.0 0.0 3.0 3.0 4.0
3+
three 0.0 0.0 2.0 0.0 2.0 1.0 2.0 4.0 0.0 3.0
4+
four 1.0 4.0 4.0 2.0 4.0 2.0 4.0 1.0 3.0 1.0
5+
five 0.0 4.0 1.0 2.0 0.0 4.0 0.0 3.0 1.0 3.0
6+
six 3.0 3.0 4.0 2.0 0.0 0.0 0.0 3.0 2.0 1.0
7+
seven 1.0 4.0 0.0 2.0 2.0 2.0 4.0 3.0 1.0 1.0

tests/test_embedding.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import finalfusion
2+
import pytest
3+
import numpy
4+
5+
# Expected L2 norms of the vectors in tests/embeddings.txt, in file order
# (e.g. the first row has squared norm 43, so its norm is ~6.5574).
TEST_NORMS = [
    6.557438373565674,
    8.83176040649414,
    6.164413928985596,
    9.165151596069336,
    7.4833149909973145,
    7.211102485656738,
    7.4833149909973145,
]
14+
15+
16+
def test_embeddings_with_norms():
    """Check the entries yielded by `iter_with_norm()`.

    Each entry's embedding, multiplied by the expected norm from TEST_NORMS,
    must match the vector for that word in tests/embeddings.txt, and each
    entry must have three elements.
    """
    embeds = finalfusion.Embeddings("tests/embeddings.fifu")

    # Map each word to its reference vector from the text file.
    with open("tests/embeddings.txt", "r", encoding="utf8") as lines:
        embeds_dict = {
            fields[0]: [float(val) for val in fields[1:]]
            for fields in (line.split(' ') for line in lines)
        }

    for entry, norm in zip(embeds.iter_with_norm(), TEST_NORMS):
        rescaled = entry[1] * norm
        expected = embeds_dict[entry[0]]
        assert numpy.allclose(rescaled, expected), \
            "Embedding from 'iter_with_norm()' fails to match!"
        assert len(entry) == 3, \
            "The number of values returned by 'iter_with_norm()' does not match!"
32+
33+
34+
def test_embeddings():
    """Check the entries yielded by iterating the embeddings directly.

    Each entry's embedding, multiplied by the expected norm from TEST_NORMS,
    must match the vector for that word in tests/embeddings.txt, and each
    entry must have two elements.
    """
    embeds = finalfusion.Embeddings("tests/embeddings.fifu")

    # Map each word to its reference vector from the text file.
    with open("tests/embeddings.txt", "r", encoding="utf8") as lines:
        embeds_dict = {
            fields[0]: [float(i) for i in fields[1:]]
            for fields in (line.split(' ') for line in lines)
        }

    for entry, norm in zip(embeds, TEST_NORMS):
        rescaled = entry[1] * norm
        expected = embeds_dict[entry[0]]
        assert numpy.allclose(rescaled, expected), \
            "Embedding from normal iterator fails to match!"
        assert len(entry) == 2, \
            "The number of values returned by normal iterator does not match!"
50+
51+
52+
def test_norms():
    """Check that the norm in each entry from `iter_with_norm()` matches TEST_NORMS.

    The norm is the third element of each yielded entry; entries are paired
    with the expected values in iteration order.
    """
    embeds = finalfusion.Embeddings("tests/embeddings.fifu")

    for embedding_with_norm, norm in zip(embeds.iter_with_norm(), TEST_NORMS):
        # Compare the value against approx(norm). Wrapping the whole
        # `a == b` comparison inside approx() never tests the norm itself.
        assert embedding_with_norm[2] == pytest.approx(norm), \
            "Norm fails to match!"

0 commit comments

Comments
 (0)