Skip to content

Commit 79be265

Browse files
committed
Implement Analogy and Similarity Queries.
1 parent 4dbfb50 commit 79be265

File tree

6 files changed

+322
-11
lines changed

6 files changed

+322
-11
lines changed

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def run(self):
7171
["src/finalfusion/subword/explicit_indexer.c"])
7272
extensions = [hash_indexers, ngrams, explicit_indexer]
7373

74+
install_requires = ["numpy", "toml"]
75+
if sys.version_info.major == 3 and sys.version_info.minor == 6:
76+
install_requires.append("dataclasses")
77+
7478
setup(name='finalfusion',
7579
author="Sebastian Pütz <[email protected]>, Daniël de Kok <[email protected]>",
7680
classifiers=[
@@ -81,7 +85,7 @@ def run(self):
8185
cmdclass={'build_ext': cython_build_ext},
8286
description="Interface to finalfusion embeddings",
8387
ext_modules=extensions,
84-
install_requires=["numpy", "toml"],
88+
install_requires=install_requires,
8589
license='BlueOak-1.0.0',
8690
packages=find_packages('src'),
8791
include_package_data=True,

src/finalfusion/embeddings.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""
22
Finalfusion Embeddings
33
"""
4+
import heapq
5+
from dataclasses import field, dataclass
46
from os import PathLike
5-
from typing import Optional, Tuple, List, Union, Any, Iterator
7+
from typing import Optional, Tuple, List, Union, Any, Iterator, Set
68

79
import numpy as np
810

@@ -418,6 +420,120 @@ def bucket_to_explicit(self) -> 'Embeddings':
418420
storage=NdArray(storage),
419421
norms=self.norms)
420422

423+
def analogy( # pylint: disable=too-many-arguments
424+
self,
425+
word1: str,
426+
word2: str,
427+
word3: str,
428+
k: int = 1,
429+
skip: Set[str] = None) -> Optional[List['SimilarityResult']]:
430+
"""
431+
Perform an analogy query.
432+
433+
This method returns words that are close in vector space to the analogy
434+
query `word1` is to `word2` as `word3` is to `?`. More concretely,
435+
it searches embeddings that are similar to:
436+
437+
*embedding(word2) - embedding(word1) + embedding(word3)*
438+
439+
Words specified in ``skip`` are not considered as answers. If ``skip``
440+
is None, the query words ``word1``, ``word2`` and ``word3`` are
441+
excluded.
442+
443+
At most, ``k`` results are returned. ``None`` is returned when no
444+
embedding could be computed for any of the tokens.
445+
446+
Parameters
447+
----------
448+
word1 : str
449+
Word1 is to...
450+
word2 : str
451+
...word2 as...
452+
word3 : str
453+
word3 is to the return value
454+
skip : Set[str]
455+
Set of strings which should not be considered as answers. Defaults
456+
to ``None`` which excludes the query strings. To allow the query
457+
strings as answers, pass an empty set.
458+
k : int
459+
Number of answers to return, defaults to 1.
460+
461+
Returns
462+
-------
463+
answers : List[SimilarityResult], optional
464+
List of answers.
465+
"""
466+
embed_a = self.embedding(word1)
467+
embed_b = self.embedding(word2)
468+
embed_c = self.embedding(word3)
469+
if embed_a is None or embed_b is None or embed_c is None:
470+
return None
471+
diff = embed_b - embed_a
472+
embed_d = embed_c + diff
473+
embed_d /= np.linalg.norm(embed_d)
474+
return self._similarity(
475+
embed_d, k, {word1, word2, word3} if skip is None else skip)
476+
477+
def word_similarity(self, query: str,
478+
k: int = 10) -> Optional[List['SimilarityResult']]:
479+
"""
480+
Retrieves the nearest neighbors of the query string.
481+
482+
The similarity between the embedding of the query and other embeddings
483+
is defined by the dot product of the embeddings. If the vectors are
484+
unit vectors, this is the cosine similarity.
485+
486+
At most, ``k`` results are returned.
487+
488+
Parameters
489+
----------
490+
query : str
491+
The query string
492+
k : int
493+
The number of neighbors to return, defaults to 10.
494+
495+
Returns
496+
-------
497+
neighbours : List[SimilarityResult], optional
498+
List of results with neighbour and similarity measure. None if no
499+
embedding can be found for ``query``.
500+
"""
501+
embed = self.embedding(query)
502+
if embed is None:
503+
return None
504+
return self._similarity(embed, k, {query})
505+
506+
def embedding_similarity(self,
507+
query: np.ndarray,
508+
k: int = 10,
509+
skip: Optional[Set[str]] = None
510+
) -> Optional[List['SimilarityResult']]:
511+
"""
512+
Retrieves the nearest neighbors of the query embedding.
513+
514+
The similarity between the query embedding and other embeddings is
515+
defined by the dot product of the embeddings. If the vectors are unit
516+
vectors, this is the cosine similarity.
517+
518+
At most, ``k`` results are returned.
519+
520+
Parameters
521+
----------
522+
query : np.ndarray
523+
The query array.
524+
k : int
525+
The number of neighbors to return, defaults to 10.
526+
skip : Set[str], optional
527+
Set of strings that should not be considered as neighbours.
528+
529+
Returns
530+
-------
531+
neighbours : List[SimilarityResult], optional
532+
List of results with neighbour and similarity measure. None if no
533+
embedding can be found for ``query``.
534+
"""
535+
return self._similarity(query, k, set() if skip is None else skip)
536+
421537
def __contains__(self, item):
422538
return item in self._vocab
423539

@@ -427,6 +543,24 @@ def __iter__(self) -> Union[Iterator[Tuple[str, np.ndarray]], Iterator[
427543
return zip(self._vocab, self._storage, self._norms)
428544
return zip(self._vocab, self._storage)
429545

546+
def _similarity(self, query: np.ndarray, k: int,
547+
skips: Set[str]) -> List['SimilarityResult']:
548+
words = self.storage[:len(self.vocab)] # type: np.ndarray
549+
sims = words.dot(query)
550+
skip_indices = set(skip for skip in (self.vocab.word_index.get(skip)
551+
for skip in skips)
552+
if skip is not None)
553+
partition = sims.argpartition(-k -
554+
len(skip_indices))[-k -
555+
len(skip_indices):]
556+
557+
heap = [] # type: List[SimilarityResult]
558+
for idx in partition:
559+
if idx not in skip_indices:
560+
heapq.heappush(
561+
heap, SimilarityResult(self.vocab.words[idx], sims[idx]))
562+
return heapq.nlargest(k, heap)
563+
430564
def _embedding(self,
431565
idx: Union[int, List[int]],
432566
out: Optional[np.ndarray] = None
@@ -524,3 +658,14 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
524658
f'Expected norms chunk, not {str(chunk_id)}')
525659

526660
return Embeddings(storage, vocab, norms, metadata)
661+
662+
663+
@dataclass(order=True)
664+
class SimilarityResult:
665+
"""
666+
Container for a Similarity result.
667+
668+
The word can be accessed through ``result.word``, the similarity through ``result.similarity``.
669+
"""
670+
word: str = field(compare=False)
671+
similarity: float

tests/conftest.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ def tests_root():
1616

1717
@pytest.fixture
1818
def simple_vocab_fifu(tests_root):
19-
yield finalfusion.vocab.load_vocab(tests_root / "data/simple_vocab.fifu")
19+
yield finalfusion.vocab.load_vocab(tests_root / "data" /
20+
"simple_vocab.fifu")
21+
22+
23+
@pytest.fixture
24+
def analogy_fifu(tests_root):
25+
yield finalfusion.load_finalfusion(tests_root / "data" /
26+
"simple_vocab.fifu")
2027

2128

2229
@pytest.fixture
@@ -45,23 +52,27 @@ def bucket_vocab_embeddings_fifu(tests_root):
4552

4653
@pytest.fixture
4754
def embeddings_text(tests_root):
48-
yield finalfusion.compat.load_text(
49-
os.path.join(tests_root, "data/embeddings.txt"))
55+
yield finalfusion.compat.load_text(tests_root / "data" / "embeddings.txt")
5056

5157

5258
@pytest.fixture
5359
def embeddings_text_dims(tests_root):
54-
yield finalfusion.compat.load_text_dims(
55-
os.path.join(tests_root, "data/embeddings.dims.txt"))
60+
yield finalfusion.compat.load_text_dims(tests_root / "data" /
61+
"embeddings.dims.txt")
5662

5763

5864
@pytest.fixture
5965
def embeddings_w2v(tests_root):
60-
yield finalfusion.compat.load_word2vec(
61-
os.path.join(tests_root, "data/embeddings.w2v"))
66+
yield finalfusion.compat.load_word2vec(tests_root / "data" /
67+
"embeddings.w2v")
6268

6369

6470
@pytest.fixture
6571
def embeddings_ft(tests_root):
66-
yield finalfusion.compat.load_fasttext(
67-
os.path.join(tests_root, "data/fasttext.bin"))
72+
yield finalfusion.compat.load_fasttext(tests_root / "data" /
73+
"fasttext.bin")
74+
75+
76+
@pytest.fixture
77+
def similarity_fifu(tests_root):
78+
yield finalfusion.load_finalfusion(tests_root / "data" / "similarity.fifu")

tests/data/similarity.fifu

16.6 KB
Binary file not shown.

tests/test_analogies.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
3+
ANALOGY_ORDER = [
4+
"Deutschland",
5+
"Westdeutschland",
6+
"Sachsen",
7+
"Mitteldeutschland",
8+
"Brandenburg",
9+
"Polen",
10+
"Norddeutschland",
11+
"Dänemark",
12+
"Schleswig-Holstein",
13+
"Österreich",
14+
"Bayern",
15+
"Thüringen",
16+
"Bundesrepublik",
17+
"Ostdeutschland",
18+
"Preußen",
19+
"Deutschen",
20+
"Hessen",
21+
"Potsdam",
22+
"Mecklenburg",
23+
"Niedersachsen",
24+
"Hamburg",
25+
"Süddeutschland",
26+
"Bremen",
27+
"Russland",
28+
"Deutschlands",
29+
"BRD",
30+
"Litauen",
31+
"Mecklenburg-Vorpommern",
32+
"DDR",
33+
"West-Berlin",
34+
"Saarland",
35+
"Lettland",
36+
"Hannover",
37+
"Rostock",
38+
"Sachsen-Anhalt",
39+
"Pommern",
40+
"Schweden",
41+
"Deutsche",
42+
"deutschen",
43+
"Westfalen",
44+
]
45+
46+
47+
def test_analogies(analogy_fifu):
48+
for idx, analogy in enumerate(
49+
analogy_fifu.analogy("Paris", "Frankreich", "Berlin", 40)):
50+
assert ANALOGY_ORDER[idx] == analogy.word
51+
52+
assert analogy_fifu.analogy("Paris", "Frankreich", "Paris", 1,
53+
{"Paris"})[0].word == "Frankreich"
54+
assert analogy_fifu.analogy("Paris", "Frankreich", "Paris",
55+
1)[0].word != "Frankreich"
56+
assert analogy_fifu.analogy("Frankreich", "Frankreich", "Frankreich", 1,
57+
set())[0].word == "Frankreich"
58+
assert analogy_fifu.analogy("Frankreich", "Frankreich", "Frankreich", 1,
59+
{"Frankreich"})[0].word != "Frankreich"
60+
61+
assert analogy_fifu.analogy("Paris", "OOV", "Paris", 1) is None

tests/test_similarity.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pytest
2+
import numpy
3+
4+
SIMILARITY_ORDER_STUTTGART_10 = [
5+
"Karlsruhe",
6+
"Mannheim",
7+
"München",
8+
"Darmstadt",
9+
"Heidelberg",
10+
"Wiesbaden",
11+
"Kassel",
12+
"Düsseldorf",
13+
"Leipzig",
14+
"Berlin",
15+
]
16+
17+
SIMILARITY_ORDER = [
18+
"Potsdam",
19+
"Hamburg",
20+
"Leipzig",
21+
"Dresden",
22+
"München",
23+
"Düsseldorf",
24+
"Bonn",
25+
"Stuttgart",
26+
"Weimar",
27+
"Berlin-Charlottenburg",
28+
"Rostock",
29+
"Karlsruhe",
30+
"Chemnitz",
31+
"Breslau",
32+
"Wiesbaden",
33+
"Hannover",
34+
"Mannheim",
35+
"Kassel",
36+
"Köln",
37+
"Danzig",
38+
"Erfurt",
39+
"Dessau",
40+
"Bremen",
41+
"Charlottenburg",
42+
"Magdeburg",
43+
"Neuruppin",
44+
"Darmstadt",
45+
"Jena",
46+
"Wien",
47+
"Heidelberg",
48+
"Dortmund",
49+
"Stettin",
50+
"Schwerin",
51+
"Neubrandenburg",
52+
"Greifswald",
53+
"Göttingen",
54+
"Braunschweig",
55+
"Berliner",
56+
"Warschau",
57+
"Berlin-Spandau",
58+
]
59+
60+
61+
def test_similarity_berlin_40(similarity_fifu):
62+
for idx, sim in enumerate(similarity_fifu.word_similarity("Berlin", 40)):
63+
assert SIMILARITY_ORDER[idx] == sim.word
64+
65+
66+
def test_similarity_stuttgart_10(similarity_fifu):
67+
for idx, sim in enumerate(similarity_fifu.word_similarity("Stuttgart",
68+
10)):
69+
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word
70+
71+
72+
def test_embedding_similarity_stuttgart_10(similarity_fifu):
73+
stuttgart = similarity_fifu.embedding("Stuttgart")
74+
sims = similarity_fifu.embedding_similarity(stuttgart, k=10)
75+
assert sims[0].word == "Stuttgart"
76+
77+
for idx, sim in enumerate(sims[1:]):
78+
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word
79+
80+
for idx, sim in enumerate(
81+
similarity_fifu.embedding_similarity(stuttgart,
82+
skip={"Stuttgart"},
83+
k=10)):
84+
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word
85+
86+
87+
def test_embedding_similarity_incompatible_shapes(similarity_fifu):
88+
incompatible_embed = numpy.ones(1, dtype=numpy.float32)
89+
with pytest.raises(ValueError):
90+
similarity_fifu.embedding_similarity(incompatible_embed)

0 commit comments

Comments
 (0)