diff --git a/torchnlp/metrics/__init__.py b/torchnlp/metrics/__init__.py index 87ccd0f..5db0192 100755 --- a/torchnlp/metrics/__init__.py +++ b/torchnlp/metrics/__init__.py @@ -1,7 +1,6 @@ from torchnlp.metrics.accuracy import get_accuracy from torchnlp.metrics.accuracy import get_token_accuracy from torchnlp.metrics.bleu import get_moses_multi_bleu - # TODO: Use `sklearn.metrics` for a `confusion_matrix` implemented with ignore_index # TODO: Use `sklearn.metrics` for a `recall` implemented with ignore_index # TODO: Use `sklearn.metrics` for a `precision` implemented with ignore_index diff --git a/torchnlp/word_to_vector/fast_text.py b/torchnlp/word_to_vector/fast_text.py index dd266e9..85c67ce 100644 --- a/torchnlp/word_to_vector/fast_text.py +++ b/torchnlp/word_to_vector/fast_text.py @@ -46,12 +46,16 @@ class FastText(_PretrainedWordVectors): * https://arxiv.org/abs/1710.04087 Args: - language (str): language of the vectors + name (str or None, optional): The name of the file that contains the vectors + url (str or None, optional): url for download if vectors not found in cache + language (str): language of the vectors (only needed when both url and name + are ignored) aligned (bool): if True: use multilingual embeddings where words with the same meaning share (approximately) the same position in the vector space across languages. if False: use regular FastText embeddings. All available languages can be found under - https://github.com/facebookresearch/MUSE#multilingual-word-embeddings + https://github.com/facebookresearch/MUSE#multilingual-word-embeddings. + (only needed when both url and name are ignored) cache (str, optional): directory for cached vectors unk_init (callback, optional): by default, initialize out-of-vocabulary word vectors to zero vectors; can be any function that takes in a Tensor and @@ -74,10 +78,12 @@ class FastText(_PretrainedWordVectors): url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec' aligned_url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-aligned/wiki.{}.align.vec' - def __init__(self, language="en", aligned=False, **kwargs): - if aligned: - url = self.aligned_url_base.format(language) - else: - url = self.url_base.format(language) - name = os.path.basename(url) + def __init__(self, language="en", url=None, name=None, aligned=False, **kwargs): + if not name: + if not url: + if aligned: + url = self.aligned_url_base.format(language) + else: + url = self.url_base.format(language) + name = os.path.basename(url) super(FastText, self).__init__(name, url=url, **kwargs)