# wmdistance.py
from __future__ import division # py3 "true division"
from itertools import chain
import logging
from numbers import Integral
try:
from queue import Queue, Empty
except ImportError:
from Queue import Queue, Empty # noqa:F401
from numpy import dot, float32 as REAL, memmap as np_memmap, \
double, array, zeros, vstack, sqrt, newaxis, integer, \
ndarray, sum as np_sum, prod, argmax
import numpy as np
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.corpora.dictionary import Dictionary
from six import string_types, integer_types
from six.moves import zip, range
from scipy import stats
from gensim.utils import deprecated
from gensim.models.utils_any2vec import (
_save_word2vec_format,
_load_word2vec_format,
ft_ngram_hashes,
)
from gensim.similarities.termsim import TermSimilarityIndex, SparseTermSimilarityMatrix
#
# For backwards compatibility, see https://github.com/RaRe-Technologies/gensim/issues/2201
#
from gensim.models.deprecated.keyedvectors import EuclideanKeyedVectors # noqa
logger = logging.getLogger(__name__)
#def wmdistance(model, document1, document2):
def wmdistance(dict1, dict2):
    """Compute the Word Mover's Distance between two documents.

    Each document is passed as a mapping whose keys are the document's tokens
    and whose values are the vectors used to compute ground distances between
    tokens. Every key counts once, so each token of a document carries the
    same weight (``1 / len(mapping)``).

    When using this code, please consider citing the following papers:

    * `Ofir Pele and Michael Werman "A linear time histogram metric for improved SIFT matching"
      <http://www.cs.huji.ac.il/~werman/Papers/ECCV2008.pdf>`_
    * `Ofir Pele and Michael Werman "Fast and robust earth mover's distances"
      <https://ieeexplore.ieee.org/document/5459199/>`_
    * `Matt Kusner et al. "From Word Embeddings To Document Distances"
      <http://proceedings.mlr.press/v37/kusnerb15.pdf>`_.

    Parameters
    ----------
    dict1 : dict
        First document: token -> vector. The vectors must support subtraction
        and ``** 2`` (presumably numpy arrays of equal length -- confirm
        against callers).
    dict2 : dict
        Second document, same layout as `dict1`.

    Returns
    -------
    float
        Word Mover's distance between the two documents. ``0.0`` when both
        documents together contain a single unique token. ``float('inf')``
        when either mapping is empty or when every computed ground distance
        is zero.

    Warnings
    --------
    This method only works if `pyemd <https://pypi.org/project/pyemd/>`_ is installed.

    The ground distance for a token pair is measured between the vector taken
    from `dict1` for the first token and the vector taken from `dict2` for the
    second one, and is then mirrored into the symmetric matrix cell. If a
    token occurs in both mappings with different vectors, the result therefore
    depends on which pair is visited first.

    Raises
    ------
    ImportError
        If `pyemd <https://pypi.org/project/pyemd/>`_ isn't installed.

    """
    # Imported at call time so that a missing pyemd only fails callers of this
    # function; the ImportError is raised before any input is inspected.
    from pyemd import emd

    # Materialise the tokens once so they can be measured, iterated repeatedly
    # and handed to `Dictionary` as plain lists.
    document1 = list(dict1)
    document2 = list(dict2)

    if not document1 or not document2:
        logger.info(
            "At least one of the documents had no words that were in the vocabulary. "
            "Aborting (returning inf)."
        )
        return float('inf')

    # Shared id <-> token mapping across both documents.
    dictionary = Dictionary(documents=[document1, document2])
    vocab_len = len(dictionary)

    if vocab_len == 1:
        # Both documents are composed by a single unique token
        return 0.0

    # Compute distance matrix. Only (token of document 1, token of document 2)
    # pairs are needed; every other cell stays zero.
    distance_matrix = zeros((vocab_len, vocab_len), dtype=double)
    for i, t1 in dictionary.items():
        if t1 not in dict1:
            continue
        vector1 = dict1[t1]  # invariant for the whole inner loop
        for j, t2 in dictionary.items():
            # A non-zero cell was already filled in through its mirrored pair.
            if t2 not in dict2 or distance_matrix[i, j] != 0.0:
                continue
            # Compute Euclidean distance between word vectors.
            distance_matrix[i, j] = distance_matrix[j, i] = sqrt(np_sum((vector1 - dict2[t2]) ** 2))

    if np_sum(distance_matrix) == 0.0:
        # `emd` gets stuck if the distance matrix contains only zeros.
        logger.info('The distance matrix is all zeros. Aborting (returning inf).')
        return float('inf')

    def nbow(document):
        """Return the normalized bag-of-words histogram of `document` over the shared vocabulary."""
        histogram = zeros(vocab_len, dtype=double)
        doc_len = len(document)
        for idx, freq in dictionary.doc2bow(document):  # Word frequencies.
            histogram[idx] = freq / float(doc_len)  # Normalized word frequencies.
        return histogram

    # Compute nBOW representation of documents.
    d1 = nbow(document1)
    d2 = nbow(document2)

    # Compute WMD.
    return emd(d1, d2, distance_matrix)