Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions glove/glove.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,17 +285,7 @@ def most_similar(self, word, number=5):
Run a similarity query, retrieving number
most similar words.
"""

if self.word_vectors is None:
raise Exception('Model must be fit before querying')

if self.dictionary is None:
raise Exception('No word dictionary supplied')

try:
word_idx = self.dictionary[word]
except KeyError:
raise Exception('Word not in dictionary')
word_idx = self._index_by_word(word)

return self._similarity_query(self.word_vectors[word_idx], number)[1:]

Expand All @@ -307,3 +297,24 @@ def most_similar_paragraph(self, paragraph, number=5, **kwargs):
paragraph_vector = self.transform_paragraph(paragraph, **kwargs)

return self._similarity_query(paragraph_vector, number)

def word_vector_by_word(self, word):
    """
    Return the embedding vector stored for `word`.

    The word is resolved to a row index via `_index_by_word`; any
    exception raised by that lookup propagates to the caller.
    """
    row = self._index_by_word(word)
    vectors = self.word_vectors

    return vectors[row]

def _index_by_word(self, word):
    """
    Translate `word` into the index stored for it in `self.dictionary`.

    Raises a plain `Exception` when `self.word_vectors` is None, when
    `self.dictionary` is None, or when the word is not a key of the
    dictionary.
    """
    if self.word_vectors is None:
        raise Exception('Model must be fit before querying')

    vocabulary = self.dictionary
    if vocabulary is None:
        raise Exception('No word dictionary supplied')

    try:
        index = vocabulary[word]
    except KeyError:
        # Report a missing word with the same generic error type as the
        # other precondition failures above.
        raise Exception('Word not in dictionary')
    else:
        return index

7 changes: 7 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ Out[19]:
('racing', 0.83157724991920212)]
```

You can also get the word vector representation:

```
In [20]: glove.word_vector_by_word('car')
Out[20]: array([0.34795333, 0.63220108, 0.06546937, ..., 0.66068305, 0.91771246, 0.01173065])
```

## Development
Pull requests are welcome.

Expand Down
48 changes: 48 additions & 0 deletions tests/test_glove.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from utils import generate_training_corpus

import pytest

def _reproduce_input_matrix(glove_model):

Expand Down Expand Up @@ -80,3 +81,50 @@ def test_fitting():
repr_matrix = _reproduce_input_matrix(glove_model)

assert ((repr_matrix - log_cooc_mat) ** 2).sum() < 1500.0

def test_word_vector_by_word():
    """Each dictionary word yields the vector stored at its index."""
    model = Glove()
    model.word_vectors = [[1, 2, 3], [4, 5, 6]]
    model.dictionary = {"first": 0, "second": 1}

    expected_by_word = (
        ("first", [1, 2, 3]),
        ("second", [4, 5, 6]),
    )

    for word, expected in expected_by_word:
        assert model.word_vector_by_word(word) == expected

def test_word_vector_by_word_without_fitting():
    """
    A model given a dictionary but no word vectors must refuse the
    lookup with the "must be fit" error.
    """
    glove_model = Glove()
    glove_model.dictionary = {"word": 0}

    with pytest.raises(Exception) as ex:
        glove_model.word_vector_by_word("word")

    # Exceptions have no `.message` attribute on Python 3; str() yields
    # the message on both Python 2 and Python 3. The check sits outside
    # the `with` block so that it actually executes after the raise.
    assert str(ex.value) == "Model must be fit before querying"

def test_word_vector_by_word_without_dictionary():
    """
    A model given word vectors but no dictionary must refuse the lookup
    with the "no dictionary" error.
    """
    glove_model = Glove()
    glove_model.word_vectors = [[1, 2, 3]]

    with pytest.raises(Exception) as ex:
        glove_model.word_vector_by_word("word")

    # Exceptions have no `.message` attribute on Python 3; str() yields
    # the message on both Python 2 and Python 3. The check sits outside
    # the `with` block so that it actually executes after the raise.
    assert str(ex.value) == "No word dictionary supplied"

def test_word_vector_by_word_without_word():
    """
    Looking up a word that is not a key of the dictionary must fail with
    the "not in dictionary" error.
    """
    glove_model = Glove()
    glove_model.word_vectors = [[1, 2, 3]]
    glove_model.dictionary = {"other": 0}

    with pytest.raises(Exception) as ex:
        glove_model.word_vector_by_word("word")

    # Exceptions have no `.message` attribute on Python 3; str() yields
    # the message on both Python 2 and Python 3. The check sits outside
    # the `with` block so that it actually executes after the raise.
    assert str(ex.value) == "Word not in dictionary"