diff --git a/glove/glove.py b/glove/glove.py index ec90ca3..7aa0639 100644 --- a/glove/glove.py +++ b/glove/glove.py @@ -285,17 +285,7 @@ def most_similar(self, word, number=5): Run a similarity query, retrieving number most similar words. """ - - if self.word_vectors is None: - raise Exception('Model must be fit before querying') - - if self.dictionary is None: - raise Exception('No word dictionary supplied') - - try: - word_idx = self.dictionary[word] - except KeyError: - raise Exception('Word not in dictionary') + word_idx = self._index_by_word(word) return self._similarity_query(self.word_vectors[word_idx], number)[1:] @@ -307,3 +297,24 @@ def most_similar_paragraph(self, paragraph, number=5, **kwargs): paragraph_vector = self.transform_paragraph(paragraph, **kwargs) return self._similarity_query(paragraph_vector, number) + + def word_vector_by_word(self, word): + """ + Given a word returns its embedding vector representation + """ + word_idx = self._index_by_word(word) + + return self.word_vectors[word_idx] + + def _index_by_word(self, word): + if self.word_vectors is None: + raise Exception('Model must be fit before querying') + + if self.dictionary is None: + raise Exception('No word dictionary supplied') + + try: + return self.dictionary[word] + except KeyError: + raise Exception('Word not in dictionary') + diff --git a/readme.md b/readme.md index 304f72a..13b8d29 100644 --- a/readme.md +++ b/readme.md @@ -59,6 +59,13 @@ Out[19]: ('racing', 0.83157724991920212)] ``` +You can also get the word vector representation: + +``` +In [20]: glove.word_vector_by_word('car') +array([0.34795333, 0.63220108, 0.06546937, ... 0.66068305, 0.91771246, 0.01173065]) +``` + ## Development Pull requests are welcome. diff --git a/tests/test_glove.py b/tests/test_glove.py index 91442ca..cf35d80 100644 --- a/tests/test_glove.py +++ b/tests/test_glove.py @@ -5,6 +5,7 @@ from utils import generate_training_corpus +import pytest def _reproduce_input_matrix(glove_model): @@ -80,3 +81,50 @@ def test_fitting(): repr_matrix = _reproduce_input_matrix(glove_model) assert ((repr_matrix - log_cooc_mat) ** 2).sum() < 1500.0 + +def test_word_vector_by_word(): + glove_model = Glove() + glove_model.word_vectors = [ + [1, 2, 3], + [4, 5, 6] + ] + glove_model.dictionary = { + "first": 0, + "second": 1 + } + + result0 = glove_model.word_vector_by_word("first") + + assert(result0 == [1, 2, 3]) + + result1 = glove_model.word_vector_by_word("second") + + assert(result1 == [4, 5, 6]) + +def test_word_vector_by_word_without_fitting(): + glove_model = Glove() + glove_model.dictionary = {"word": 0} + + with pytest.raises(Exception) as ex: + glove_model.word_vector_by_word("word") + + assert(ex.value.message == "Model must be fit before querying") + +def test_word_vector_by_word_without_dictionary(): + glove_model = Glove() + glove_model.word_vectors = [[1, 2, 3]] + + with pytest.raises(Exception) as ex: + glove_model.word_vector_by_word("word") + + assert(ex.value.message == "No word dictionary supplied") + +def test_word_vector_by_word_without_word(): + glove_model = Glove() + glove_model.word_vectors = [[1, 2, 3]] + glove_model.dictionary = {"other": 0} + + with pytest.raises(Exception) as ex: + glove_model.word_vector_by_word("word") + + assert(ex.value.message == "Word not in dictionary")