Skip to content

Commit 3ec049c

Browse files
NianhengWu authored and Daniël de Kok committed
Expose ngram_indices and subword_indices
1 parent 61a8c62 commit 3ec049c

File tree

5 files changed

+100
-2
lines changed

5 files changed

+100
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ features = ["extension-module"]
2121
[dependencies]
2222
itertools = "0.8"
2323
failure = "0.1"
24-
finalfusion = "0.8"
24+
finalfusion = "0.8.2"
2525
libc = "0.2"
2626
ndarray = "0.12"
2727
numpy = "0.6"

src/vocab.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::cell::RefCell;
22
use std::rc::Rc;
33

4-
use finalfusion::chunks::vocab::WordIndex;
4+
use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex};
55
use finalfusion::prelude::*;
66
use pyo3::class::sequence::PySequenceProtocol;
77
use pyo3::exceptions;
@@ -34,6 +34,28 @@ impl PyVocab {
3434
}
3535
})
3636
}
37+
38+
fn ngram_indices(&self, word: &str) -> PyResult<Option<Vec<(String, usize)>>> {
39+
let embeds = self.embeddings.borrow();
40+
match embeds.vocab() {
41+
VocabWrap::FastTextSubwordVocab(inner) => Ok(inner.ngram_indices(word)),
42+
VocabWrap::FinalfusionSubwordVocab(inner) => Ok(inner.ngram_indices(word)),
43+
VocabWrap::SimpleVocab(_) => Err(exceptions::ValueError::py_err(
44+
"querying n-gram indices is not supported for this vocabulary",
45+
)),
46+
}
47+
}
48+
49+
fn subword_indices(&self, word: &str) -> PyResult<Option<Vec<usize>>> {
50+
let embeds = self.embeddings.borrow();
51+
match embeds.vocab() {
52+
VocabWrap::FastTextSubwordVocab(inner) => Ok(inner.subword_indices(word)),
53+
VocabWrap::FinalfusionSubwordVocab(inner) => Ok(inner.subword_indices(word)),
54+
VocabWrap::SimpleVocab(_) => Err(exceptions::ValueError::py_err(
55+
"querying subwords' indices is not supported for this vocabulary",
56+
)),
57+
}
58+
}
3759
}
3860

3961
#[pyproto]

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ def similarity_fifu(tests_root):
2525
yield finalfusion.Embeddings(os.path.join(tests_root, "similarity.fifu"))
2626

2727

28+
@pytest.fixture
def subword_fifu(tests_root):
    """Yield the embeddings loaded from ``subword.fifu`` under ``tests_root``."""
    yield finalfusion.Embeddings(os.path.join(tests_root, "subword.fifu"))
31+
32+
2833
@pytest.fixture
def embeddings_text_dims(tests_root):
    """Yield the embeddings read via ``read_text_dims`` from ``embeddings.dims.txt`` under ``tests_root``."""
    yield finalfusion.Embeddings.read_text_dims(os.path.join(tests_root, "embeddings.dims.txt"))

tests/subword.fifu

20.5 KB
Binary file not shown.

tests/test_vocab.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,74 @@
1+
# Expected (n-gram, index) pairs for the word "tübingen", ordered by
# ascending index. The n-grams are exactly the 3- to 6-character substrings
# of the bracketed form "<tübingen>".
TEST_NGRAM_INDICES = [
    ('tüb', 14),
    ('en>', 69),
    ('übinge', 74),
    ('gen', 124),
    ('ing', 168),
    ('ngen', 181),
    ('bing', 197),
    ('inge', 246),
    ('übin', 250),
    ('tübi', 276),
    ('bingen', 300),
    ('<tübin', 308),
    ('bin', 325),
    ('übing', 416),
    ('gen>', 549),
    ('ngen>', 590),
    ('ingen>', 648),
    ('tübing', 651),
    ('übi', 707),
    ('ingen', 717),
    ('binge', 761),
    ('<tübi', 817),
    ('<tü', 820),
    ('<tüb', 857),
    ('nge', 860),
    ('tübin', 1007),
]
54+
55+
156
def test_embeddings_with_norms_oov(embeddings_fifu):
257
vocab = embeddings_fifu.vocab()
358
assert vocab.item_to_indices("Something out of vocabulary") is None
59+
60+
61+
def test_ngram_indices(subword_fifu):
    """The n-grams of "tübingen" and their indices must match the expected table."""
    vocab = subword_fifu.vocab()
    ngram_indices = sorted(vocab.ngram_indices("tübingen"), key=lambda tup: tup[1])
    # Compare the complete lists: iterating over zip() stops at the shorter
    # sequence, so a missing or surplus n-gram would go unnoticed.
    assert ngram_indices == TEST_NGRAM_INDICES
67+
68+
69+
def test_subword_indices(subword_fifu):
    """The subword indices of "tübingen" must match the indices of the expected table."""
    vocab = subword_fifu.vocab()
    subword_indices = sorted(vocab.subword_indices("tübingen"))
    expected_indices = [index for _, index in TEST_NGRAM_INDICES]
    # Compare the complete lists: iterating over zip() stops at the shorter
    # sequence, so a missing or surplus index would go unnoticed.
    assert subword_indices == expected_indices

0 commit comments

Comments
 (0)