Skip to content

Commit 1b16a2b

Browse files
sebpuetzDaniël de Kok
authored andcommitted
Add masked analogy functionality.
1 parent d658c43 commit 1b16a2b

File tree

5 files changed

+75
-9
lines changed

5 files changed

+75
-9
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ features = ["extension-module"]
2020

2121
[dependencies]
2222
failure = "0.1"
23-
finalfusion = "0.6"
23+
finalfusion = "0.7"
2424
libc = "0.2"
2525
ndarray = "0.12"
2626
numpy = "0.5"

src/embeddings.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,15 @@ impl PyEmbeddings {
6363
///
6464
/// This returns words for the analogy query *w1* is to *w2*
6565
/// as *w3* is to ?.
66-
#[args(limit = 10)]
66+
#[args(limit = 10, mask = "(true, true, true)")]
6767
fn analogy(
6868
&self,
6969
py: Python,
7070
word1: &str,
7171
word2: &str,
7272
word3: &str,
7373
limit: usize,
74+
mask: (bool, bool, bool),
7475
) -> PyResult<Vec<PyObject>> {
7576
use EmbeddingsWrap::*;
7677
let embeddings = self.embeddings.borrow();
@@ -83,10 +84,11 @@ impl PyEmbeddings {
8384
}
8485
};
8586

86-
let results = match embeddings.analogy(word1, word2, word3, limit) {
87-
Some(results) => results,
88-
None => return Err(exceptions::KeyError::py_err("Unknown word or n-grams")),
89-
};
87+
let results =
88+
match embeddings.analogy_masked(word1, word2, word3, limit, [mask.0, mask.1, mask.2]) {
89+
Some(results) => results,
90+
None => return Err(exceptions::KeyError::py_err("Unknown word or n-grams")),
91+
};
9092

9193
let mut r = Vec::with_capacity(results.len());
9294
for ws in results {

tests/analogy.fifu

17.4 KB
Binary file not shown.

tests/test_analogy.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import finalfusion
2+
3+
ANALOGY_ORDER = [
4+
"Deutschland",
5+
"Westdeutschland",
6+
"Sachsen",
7+
"Mitteldeutschland",
8+
"Brandenburg",
9+
"Polen",
10+
"Norddeutschland",
11+
"Dänemark",
12+
"Schleswig-Holstein",
13+
"Österreich",
14+
"Bayern",
15+
"Thüringen",
16+
"Bundesrepublik",
17+
"Ostdeutschland",
18+
"Preußen",
19+
"Deutschen",
20+
"Hessen",
21+
"Potsdam",
22+
"Mecklenburg",
23+
"Niedersachsen",
24+
"Hamburg",
25+
"Süddeutschland",
26+
"Bremen",
27+
"Russland",
28+
"Deutschlands",
29+
"BRD",
30+
"Litauen",
31+
"Mecklenburg-Vorpommern",
32+
"DDR",
33+
"West-Berlin",
34+
"Saarland",
35+
"Lettland",
36+
"Hannover",
37+
"Rostock",
38+
"Sachsen-Anhalt",
39+
"Pommern",
40+
"Schweden",
41+
"Deutsche",
42+
"deutschen",
43+
"Westfalen",
44+
]
45+
46+
def test_analogies():
47+
embeds = finalfusion.Embeddings('tests/analogy.fifu')
48+
for idx, analogy in enumerate(embeds.analogy("Paris", "Frankreich", "Berlin", 40)):
49+
assert ANALOGY_ORDER[idx] == analogy.word
50+
51+
assert embeds.analogy("Paris", "Frankreich", "Paris", 1, (True, False, True))[0].word == "Frankreich"
52+
assert embeds.analogy("Paris", "Frankreich", "Paris", 1, (True, True, True))[0].word != "Frankreich"
53+
assert embeds.analogy("Frankreich", "Frankreich", "Frankreich", 1, (False, False, False))[0].word == "Frankreich"
54+
assert embeds.analogy("Frankreich", "Frankreich", "Frankreich", 1, (False, False, True))[0].word != "Frankreich"
55+
try:
56+
embeds.analogy("Paris", "Frankreich", "Paris", 1, (True, True))
57+
assert True == False
58+
except:
59+
()
60+
try:
61+
embeds.analogy("Paris", "Frankreich", "Paris", 1, (True, True, True, True))
62+
assert True == False
63+
except:
64+
()

0 commit comments

Comments
 (0)