Skip to content

Commit d0da7fe

Browse files
danieldkDaniël de Kok
authored and committed
Add the static read_{text,text_dims,word2vec} methods
1 parent 11980ac commit d0da7fe

File tree

4 files changed

+95
-18
lines changed

4 files changed

+95
-18
lines changed

src/embeddings.rs

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ use std::io::{BufReader, BufWriter};
44
use std::rc::Rc;
55

66
use failure::Error;
7+
use finalfusion::io as ffio;
78
use finalfusion::metadata::Metadata;
89
use finalfusion::prelude::*;
910
use finalfusion::similarity::*;
11+
use finalfusion::text::{ReadText, ReadTextDims};
12+
use finalfusion::word2vec::ReadWord2Vec;
1013
use itertools::Itertools;
1114
use ndarray::Array2;
1215
use numpy::{IntoPyArray, PyArray1, PyArray2};
@@ -44,9 +47,9 @@ impl PyEmbeddings {
4447
// First try to load embeddings with viewable storage. If that
4548
// fails, attempt to load the embeddings as non-viewable
4649
// storage.
47-
let embeddings = match load_embeddings(path, mmap) {
50+
let embeddings = match read_embeddings(path, mmap) {
4851
Ok(e) => Rc::new(RefCell::new(EmbeddingsWrap::View(e))),
49-
Err(_) => load_embeddings(path, mmap)
52+
Err(_) => read_embeddings(path, mmap)
5053
.map(|e| Rc::new(RefCell::new(EmbeddingsWrap::NonView(e))))
5154
.map_err(|err| exceptions::IOError::py_err(err.to_string()))?,
5255
};
@@ -56,6 +59,42 @@ impl PyEmbeddings {
5659
Ok(())
5760
}
5861

62+
/// read_text(path,/)
63+
/// --
64+
///
65+
/// Read embeddings in text format. This format uses one line per
66+
/// embedding. Each line starts with the word in UTF-8, followed
67+
/// by its vector components encoded in ASCII. The word and its
68+
/// components are separated by spaces.
69+
#[staticmethod]
70+
fn read_text(path: &str) -> PyResult<PyEmbeddings> {
71+
read_non_fifu_embeddings(path, |r| Embeddings::read_text(r))
72+
}
73+
74+
/// read_text_dims(path,/)
75+
/// --
76+
///
77+
/// Read embeddings in text format with dimensions. In this format,
78+
/// the first line states the shape of the embedding matrix. The
79+
/// number of rows (words) and columns (embedding dimensionality) is
80+
/// separated by a space character. The remainder of the file uses
81+
/// one line per embedding. Each line starts with the word in UTF-8,
82+
/// followed by its vector components encoded in ASCII. The word and
83+
/// its components are separated by spaces.
84+
#[staticmethod]
85+
fn read_text_dims(path: &str) -> PyResult<PyEmbeddings> {
86+
read_non_fifu_embeddings(path, |r| Embeddings::read_text_dims(r))
87+
}
88+
89+
/// read_word2vec(path,/)
90+
/// --
91+
///
92+
/// Read embeddings in the word2vec binary format.
93+
#[staticmethod]
94+
fn read_word2vec(path: &str) -> PyResult<PyEmbeddings> {
95+
read_non_fifu_embeddings(path, |r| Embeddings::read_word2vec_binary(r))
96+
}
97+
5998
/// Get the model's vocabulary.
6099
fn vocab(&self) -> PyResult<PyVocab> {
61100
Ok(PyVocab::new(self.embeddings.clone()))
@@ -283,7 +322,7 @@ impl PyIterProtocol for PyEmbeddings {
283322
}
284323
}
285324

286-
fn load_embeddings<S>(path: &str, mmap: bool) -> Result<Embeddings<VocabWrap, S>, Error>
325+
fn read_embeddings<S>(path: &str, mmap: bool) -> Result<Embeddings<VocabWrap, S>, Error>
287326
where
288327
Embeddings<VocabWrap, S>: ReadEmbeddings + MmapEmbeddings,
289328
{
@@ -298,3 +337,27 @@ where
298337

299338
Ok(embeddings)
300339
}
340+
341+
fn read_non_fifu_embeddings<R>(path: &str, read_embeddings: R) -> PyResult<PyEmbeddings>
342+
where
343+
R: FnOnce(&mut BufReader<File>) -> ffio::Result<Embeddings<SimpleVocab, NdArray>>,
344+
{
345+
let f = File::open(path).map_err(|err| {
346+
exceptions::IOError::py_err(format!(
347+
"Cannot read text embeddings from '{}': {}'",
348+
path, err
349+
))
350+
})?;
351+
let mut reader = BufReader::new(f);
352+
353+
let embeddings = read_embeddings(&mut reader).map_err(|err| {
354+
exceptions::IOError::py_err(format!(
355+
"Cannot read text embeddings from '{}': {}'",
356+
path, err
357+
))
358+
})?;
359+
360+
Ok(PyEmbeddings {
361+
embeddings: Rc::new(RefCell::new(EmbeddingsWrap::View(embeddings.into()))),
362+
})
363+
}

tests/conftest.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,19 @@ def embeddings_fifu(tests_root):
1717

1818
@pytest.fixture
1919
def embeddings_text(tests_root):
20-
embeds = dict()
21-
22-
with open(os.path.join(tests_root, "embeddings.txt"), "r", encoding="utf8") as lines:
23-
for line in lines:
24-
line_list = line.split(' ')
25-
embeds[line_list[0]] = numpy.array(
26-
[float(c) for c in line_list[1:]])
27-
28-
yield embeds
20+
yield finalfusion.Embeddings.read_text(os.path.join(tests_root, "embeddings.txt"))
2921

3022

3123
@pytest.fixture
3224
def similarity_fifu(tests_root):
3325
yield finalfusion.Embeddings(os.path.join(tests_root, "similarity.fifu"))
3426

3527

28+
@pytest.fixture
29+
def embeddings_text_dims(tests_root):
30+
yield finalfusion.Embeddings.read_text_dims(os.path.join(tests_root, "embeddings.dims.txt"))
31+
32+
3633
@pytest.fixture
3734
def tests_root():
3835
yield os.path.dirname(__file__)

tests/embeddings.dims.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
7 10
2+
one 3.0 1.0 0.0 0.0 0.0 0.0 2.0 2.0 4.0 3.0
3+
two 2.0 3.0 3.0 3.0 3.0 2.0 0.0 3.0 3.0 4.0
4+
three 0.0 0.0 2.0 0.0 2.0 1.0 2.0 4.0 0.0 3.0
5+
four 1.0 4.0 4.0 2.0 4.0 2.0 4.0 1.0 3.0 1.0
6+
five 0.0 4.0 1.0 2.0 0.0 4.0 0.0 3.0 1.0 3.0
7+
six 3.0 3.0 4.0 2.0 0.0 0.0 0.0 3.0 2.0 1.0
8+
seven 1.0 4.0 0.0 2.0 2.0 2.0 4.0 3.0 1.0 1.0

tests/test_embedding.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,22 @@
1212
]
1313

1414

15-
def test_embeddings(embeddings_fifu, embeddings_text):
16-
for embedding, norm in zip(
17-
embeddings_fifu, TEST_NORMS):
18-
unnormed_embed = embedding.embedding * embedding.norm
19-
test_embed = embeddings_text[embedding.word]
15+
def test_embeddings(embeddings_fifu, embeddings_text, embeddings_text_dims):
16+
# Check that we cover all words from all embeddings below.
17+
assert len(embeddings_fifu.vocab()) == 7
18+
assert len(embeddings_text.vocab()) == 7
19+
assert len(embeddings_text_dims.vocab()) == 7
20+
21+
# Check that the finalfusion embeddings have the correct dimensionality
22+
# The correct dimensionality of the other embedding types is asserted
23+
# in the pairwise comparisons below.
24+
assert embeddings_fifu.matrix_copy().shape == (7, 10)
25+
26+
for embedding in embeddings_fifu:
27+
assert numpy.allclose(
28+
embedding.embedding, embeddings_text[embedding.word]), "FiFu and text embedding mismatch"
2029
assert numpy.allclose(
21-
unnormed_embed, test_embed), "Embedding from 'iter_with_norm()' fails to match!"
30+
embedding.embedding, embeddings_text_dims[embedding.word]), "FiFu and textdims embedding mismatch"
2231

2332

2433
def test_embeddings_with_norms_oov(embeddings_fifu):

0 commit comments

Comments
 (0)