11"""
22Finalfusion Embeddings
33"""
4+ import heapq
5+ from dataclasses import field , dataclass
46from os import PathLike
5- from typing import Optional , Tuple , List , Union , Any , Iterator
7+ from typing import Optional , Tuple , List , Union , Any , Iterator , Set
68
79import numpy as np
810
@@ -418,6 +420,120 @@ def bucket_to_explicit(self) -> 'Embeddings':
418420 storage = NdArray (storage ),
419421 norms = self .norms )
420422
423+ def analogy ( # pylint: disable=too-many-arguments
424+ self ,
425+ word1 : str ,
426+ word2 : str ,
427+ word3 : str ,
428+ k : int = 1 ,
429+ skip : Set [str ] = None ) -> Optional [List ['SimilarityResult' ]]:
430+ """
431+ Perform an analogy query.
432+
433+ This method returns words that are close in vector space the analogy
434+ query `word1` is to `word2` as `word3` is to `?`. More concretely,
435+ it searches embeddings that are similar to:
436+
437+ *embedding(word2) - embedding(word1) + embedding(word3)*
438+
439+ Words specified in ``skip`` are not considered as answers. If ``skip``
440+ is None, the query words ``word1``, ``word2`` and ``word3`` are
441+ excluded.
442+
443+ At most, ``k`` results are returned. ``None`` is returned when no
444+ embedding could be computed for any of the tokens.
445+
446+ Parameters
447+ ----------
448+ word1 : str
449+ Word1 is to...
450+ word2 : str
451+ word2 like...
452+ word3 : str
453+ word3 is to the return value
454+ skip : Set[str]
455+ Set of strings which should not be considered as answers. Defaults
456+ to ``None`` which excludes the query strings. To allow the query
457+ strings as answers, pass an empty set.
458+ k : int
459+ Number of answers to return, defaults to 1.
460+
461+ Returns
462+ -------
463+ answers : List[SimilarityResult]
464+ List of answers.
465+ """
466+ embed_a = self .embedding (word1 )
467+ embed_b = self .embedding (word2 )
468+ embed_c = self .embedding (word3 )
469+ if embed_a is None or embed_b is None or embed_c is None :
470+ return None
471+ diff = embed_b - embed_a
472+ embed_d = embed_c + diff
473+ embed_d /= np .linalg .norm (embed_d )
474+ return self ._similarity (
475+ embed_d , k , {word1 , word2 , word3 } if skip is None else skip )
476+
477+ def word_similarity (self , query : str ,
478+ k : int = 10 ) -> Optional [List ['SimilarityResult' ]]:
479+ """
480+ Retrieves the nearest neighbors of the query string.
481+
482+ The similarity between the embedding of the query and other embeddings
483+ is defined by the dot product of the embeddings. If the vectors are
484+ unit vectors, this is the cosine similarity.
485+
486+ At most, ``k`` results are returned.
487+
488+ Parameters
489+ ----------
490+ query : str
491+ The query string
492+ k : int
493+ The number of neighbors to return, defaults to 10.
494+
495+ Returns
496+ -------
497+ neighbours : List[Tuple[str, float], optional
498+ List of tuples with neighbour and similarity measure. None if no
499+ embedding can be found for ``query``.
500+ """
501+ embed = self .embedding (query )
502+ if embed is None :
503+ return None
504+ return self ._similarity (embed , k , {query })
505+
506+ def embedding_similarity (self ,
507+ query : np .ndarray ,
508+ k : int = 10 ,
509+ skip : Optional [Set [str ]] = None
510+ ) -> Optional [List ['SimilarityResult' ]]:
511+ """
512+ Retrieves the nearest neighbors of the query embedding.
513+
514+ The similarity between the query embedding and other embeddings is
515+ defined by the dot product of the embeddings. If the vectors are unit
516+ vectors, this is the cosine similarity.
517+
518+ At most, ``k`` results are returned.
519+
520+ Parameters
521+ ----------
522+ query : str
523+ The query array.
524+ k : int
525+ The number of neighbors to return, defaults to 10.
526+ skip : Set[str], optional
527+ Set of strings that should not be considered as neighbours.
528+
529+ Returns
530+ -------
531+ neighbours : List[Tuple[str, float], optional
532+ List of tuples with neighbour and similarity measure. None if no
533+ embedding can be found for ``query``.
534+ """
535+ return self ._similarity (query , k , set () if skip is None else skip )
536+
421537 def __contains__ (self , item ):
422538 return item in self ._vocab
423539
@@ -427,6 +543,24 @@ def __iter__(self) -> Union[Iterator[Tuple[str, np.ndarray]], Iterator[
427543 return zip (self ._vocab , self ._storage , self ._norms )
428544 return zip (self ._vocab , self ._storage )
429545
546+ def _similarity (self , query : np .ndarray , k : int ,
547+ skips : Set [str ]) -> List ['SimilarityResult' ]:
548+ words = self .storage [:len (self .vocab )] # type: np.ndarray
549+ sims = words .dot (query )
550+ skip_indices = set (skip for skip in (self .vocab .word_index .get (skip )
551+ for skip in skips )
552+ if skip is not None )
553+ partition = sims .argpartition (- k -
554+ len (skip_indices ))[- k -
555+ len (skip_indices ):]
556+
557+ heap = [] # type: List[SimilarityResult]
558+ for idx in partition :
559+ if idx not in skip_indices :
560+ heapq .heappush (
561+ heap , SimilarityResult (self .vocab .words [idx ], sims [idx ]))
562+ return heapq .nlargest (k , heap )
563+
430564 def _embedding (self ,
431565 idx : Union [int , List [int ]],
432566 out : Optional [np .ndarray ] = None
@@ -524,3 +658,14 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
524658 f'Expected norms chunk, not { str (chunk_id )} ' )
525659
526660 return Embeddings (storage , vocab , norms , metadata )
661+
662+
663+ @dataclass (order = True )
664+ class SimilarityResult :
665+ """
666+ Container for a Similarity result.
667+
668+ The word can be accessed through ``result.word``, the similarity through ``result.similarity``.
669+ """
670+ word : str = field (compare = False )
671+ similarity : float
0 commit comments