"""
fastText IO compat module.
"""

import sys
from os import PathLike
from typing import Union, BinaryIO, cast, List, Any, Dict

import numpy as np

from finalfusion import Embeddings
from finalfusion._util import _normalize_ndarray_storage
from finalfusion.io import _read_required_binary, _write_binary, _serialize_array_as_le
from finalfusion.metadata import Metadata
from finalfusion.storage import NdArray
from finalfusion.subword import FastTextIndexer
from finalfusion.vocab import FastTextVocab, Vocab, SimpleVocab

_FT_MAGIC = 793_712_314
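# The magic above identifies native fastText binary files; _read_ft_header below
# additionally requires file format version 12, the only version this module handles.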


def load_fasttext(file: Union[str, bytes, int, PathLike]) -> Embeddings:
    """
    Read embeddings from a file in fastText format.

    The returned embeddings have a FastTextVocab, NdArray storage and a Norms chunk.

    Loading embeddings with this method will precompute embeddings for each word by averaging all
    of its subword embeddings together with the distinct word vector. Additionally, all precomputed
    vectors are l2-normalized and the corresponding norms are stored in the Norms. The subword
    embeddings are **not** l2-normalized.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Path to a file with embeddings in fastText binary format.

    Returns
    -------
    embeddings : Embeddings
        The embeddings from the input file.
    """
    with open(file, 'rb') as inf:
        _read_ft_header(inf)
        metadata = _read_ft_cfg(inf)
        vocab = _read_ft_vocab(inf, metadata['buckets'], metadata['min_n'],
                               metadata['max_n'])
        storage = _read_ft_storage(inf, vocab)
        norms = _normalize_ndarray_storage(storage[:len(vocab)])
        return Embeddings(storage=storage,
                          vocab=vocab,
                          norms=norms,
                          metadata=metadata)

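# A minimal usage sketch (the path and word are illustrative, and the lookup call
# assumes finalfusion's standard Embeddings.embedding API):
#
#     embeds = load_fasttext("model.bin")
#     vec = embeds.embedding("berlin")   # precomputed, l2-normalized word vector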

def write_fasttext(file: Union[str, bytes, int, PathLike], embeds: Embeddings):
    """
    Write embeddings in fastText format.

    The fastText format requires a value for every config key listed below; keys missing
    from the embeddings' Metadata fall back to the given defaults:

    * dims: int (inferred from model)
    * window_size: int (default -1)
    * epoch: int (default -1)
    * min_count: int (default -1)
    * ns: int (default -1)
    * word_ngrams: int (default 1)
    * loss: one of ``['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']`` (default Softmax)
    * model: one of ``['CBOW', 'SkipGram', 'Supervised']`` (default SkipGram)
    * buckets: int (inferred from model)
    * min_n: int (inferred from model)
    * max_n: int (inferred from model)
    * lr_update_rate: int (default -1)
    * sampling_threshold: float (default -1)

    ``dims``, ``buckets``, ``min_n`` and ``max_n`` are inferred from the model. If other values
    are unspecified, a default value of ``-1`` is used for all numerical fields. Loss defaults
    to ``Softmax``, model to ``SkipGram``. Unknown values for ``loss`` and ``model`` are
    overwritten with the defaults, since the resulting file would otherwise be incompatible
    with fastText.

    Some information from the original fastText model is lost on conversion, e.g.:

    * word frequencies
    * n_tokens

    Embeddings are un-normalized before serialization: if norms are present, each embedding is
    scaled by the associated norm. Additionally, the original state of the embedding matrix is
    restored: precomputation and l2-normalization of word embeddings are undone.

    Only embeddings with a FastTextVocab or SimpleVocab can be serialized to this format.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Output file
    embeds : Embeddings
        Embeddings to write
    """
    with open(file, 'wb') as outf:
        if not isinstance(embeds.vocab, (FastTextVocab, SimpleVocab)):
            raise ValueError(
                f'Expected FastTextVocab or SimpleVocab, not: {type(embeds.vocab).__name__}'
            )
        _write_binary(outf, "<ii", _FT_MAGIC, 12)
        _write_ft_cfg(outf, embeds)
        _write_ft_vocab(outf, embeds.vocab)
        # input matrix: quantization flag (false) followed by rows and cols
        _write_binary(outf, "<?QQ", 0, *embeds.storage.shape)
        if isinstance(embeds.vocab, SimpleVocab):
            _write_ft_storage_simple(outf, embeds)
        else:
            _write_ft_storage_subwords(outf, embeds)
        # fastText expects an output matrix after the input matrix, again preceded by a
        # quantization flag and its dimensions; the storage is written as a dummy here.
        _write_binary(outf, "<?QQ", 0, *embeds.storage.shape)
        _serialize_array_as_le(outf, embeds.storage)

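# Round-trip sketch (file names are illustrative):
#
#     embeds = load_fasttext("model.bin")
#     write_fasttext("model-roundtrip.bin", embeds)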

def _read_ft_header(file: BinaryIO):
    """
    Helper method to verify version and magic.
    """
    magic, version = _read_required_binary(file, "<ii")
    if magic != _FT_MAGIC:
        raise ValueError(f"Magic should be 793_712_314, not: {magic}")
    if version != 12:
        raise ValueError(f"Expected version 12, not: {version}")


def _read_ft_cfg(file: BinaryIO) -> Metadata:
    """
    Constructs metadata from fastText config.
    """
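    # The config block is 12 little-endian int32s followed by one float64, stored in
    # the order given by _FT_REQUIRED_CFG_KEYS; loss and model are integral enums on disk.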
    cfg = list(_read_required_binary(file, "<12id"))
    losses = ['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']
    cfg[6] = losses[cfg[6] - 1]
    models = ['CBOW', 'SkipGram', 'Supervised']
    cfg[7] = models[cfg[7] - 1]
    return Metadata(dict(zip(_FT_REQUIRED_CFG_KEYS, cfg)))


def _read_ft_vocab(file: BinaryIO, buckets: int, min_n: int,
                   max_n: int) -> Union[FastTextVocab, SimpleVocab]:
    """
    Helper method to read a vocab from a fastText file.

    Returns a SimpleVocab if min_n is 0, otherwise a FastTextVocab.
    """
    # discard n_words
    vocab_size, _n_words, n_labels = _read_required_binary(file, "<iii")
    if n_labels:
        raise NotImplementedError(
            "fastText prediction models are not supported")
    # discard n_tokens
    _read_required_binary(file, "<q")

    prune_idx_size = _read_required_binary(file, "<q")[0]
    if prune_idx_size > 0:
        raise NotImplementedError("Pruned vocabs are not supported")

    if min_n:
        return _read_ft_subwordvocab(file, buckets, min_n, max_n, vocab_size)
    return SimpleVocab([_read_binary_word(file) for _ in range(vocab_size)])


def _read_ft_subwordvocab(file: BinaryIO, buckets: int, min_n: int, max_n: int,
                          vocab_size: int) -> FastTextVocab:
    """
    Helper method to build a FastTextVocab from a fastText file.
    """
    words = [_read_binary_word(file) for _ in range(vocab_size)]
    indexer = FastTextIndexer(buckets, min_n, max_n)
    return FastTextVocab(words, indexer)


def _read_binary_word(file: BinaryIO) -> str:
    """
    Helper method to read null-terminated binary strings.
    """
    word = bytearray()
    while True:
        byte = file.read(1)
        if byte == b'\x00':
            break
        if byte == b'':
            raise EOFError
        word.extend(byte)
    # discard frequency
    _ = _read_required_binary(file, "<q")
    entry_type = _read_required_binary(file, "b")[0]
    if entry_type != 0:
        raise ValueError(f'Non word entry: {word}')

    # pylint: disable=fixme # XXX handle unicode errors
    return word.decode("utf8")


def _read_ft_storage(file: BinaryIO, vocab: Vocab) -> NdArray:
    """
    Helper method to read fastText storage.

    If vocab is a SimpleVocab, the matrix is read and returned as is.
    If vocab is a FastTextVocab, the word representations are precomputed based
    on the vocab.
    """
    quantized = _read_required_binary(file, "?")[0]
    if quantized:
        raise NotImplementedError(
            "Quantized storage is not supported for fastText models")
    rows, cols = _read_required_binary(file, "<qq")
    matrix = np.fromfile(file=file, count=rows * cols,
                         dtype=np.float32).reshape((rows, cols))
    if sys.byteorder == 'big':
        matrix.byteswap(inplace=True)
    if isinstance(vocab, FastTextVocab):
        _precompute_word_vecs(vocab, matrix)
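    # Anything following the input matrix (fastText's output matrix) is not read.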
    return NdArray(matrix)


def _precompute_word_vecs(vocab: FastTextVocab, matrix: np.ndarray):
    """
    Helper method to precompute word vectors.

    Averages the distinct word representation and the corresponding ngram
    embeddings.
    """
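    # For the word at row i with ngram rows G, the stored vector becomes
    # (matrix[i] + sum(matrix[g] for g in G)) / (len(G) + 1).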
    for i, word in enumerate(vocab):
        indices = [i]
        if isinstance(vocab, FastTextVocab):
            subword_indices = cast(
                List[int], vocab.subword_indices(word, with_ngrams=False))
            indices += subword_indices
        matrix[i] = matrix[indices].mean(0, keepdims=False)


def _write_ft_cfg(file: BinaryIO, embeds: Embeddings):
    """
    Helper method to write fastText config.

    * dims: taken from embeds
    * window_size: -1 if unspecified
    * epoch: -1 if unspecified
    * min_count: -1 if unspecified
    * ns: -1 if unspecified
    * word_ngrams: 1
    * loss: one of `['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']`, defaults to 'Softmax'
    * model: one of `['CBOW', 'SkipGram', 'Supervised']`, defaults to SkipGram
    * buckets: taken from embeds, 0 if SimpleVocab
    * min_n: taken from embeds, 0 if SimpleVocab
    * max_n: taken from embeds, 0 if SimpleVocab
    * lr_update_rate: -1 if unspecified
    * sampling_threshold: -1 if unspecified

    ``loss`` and ``model`` values that are not listed above are overwritten with the defaults.
    """
    # declare some dummy values that we can't get from embeds
    meta = {
        'window_size': -1,
        'epoch': -1,
        'min_count': -1,
        'ns': -1,
        'word_ngrams': 1,
        'loss': 'Softmax',
        # fastText uses an integral enum with the values 1, 2 and 3, so there is no
        # safe placeholder (such as 0) for unknown losses or models.
        'model': 'SkipGram',
        'lr_update_rate': -1,
        'sampling_threshold': -1
    }  # type: Dict[str, Any]
    if embeds.metadata is not None:
        meta.update(embeds.metadata)
    meta['dims'] = embeds.storage.shape[1]
    if isinstance(embeds.vocab, FastTextVocab):
        meta['min_n'] = embeds.vocab.min_n
        meta['max_n'] = embeds.vocab.max_n
        meta['buckets'] = embeds.vocab.subword_indexer.n_buckets
    else:
        meta['min_n'] = 0
        meta['max_n'] = 0
        meta['buckets'] = 0
    cfg = [meta[k] for k in _FT_REQUIRED_CFG_KEYS]
    # see explanation above why we need to select some known value
    losses = {'HierarchicalSoftmax': 1, 'NegativeSampling': 2, 'Softmax': 3}
    cfg[6] = losses.get(cfg[6], 3)
    models = {'CBOW': 1, 'SkipGram': 2, 'Supervised': 3}
    cfg[7] = models.get(cfg[7], 2)
    _write_binary(file, "<12id", *cfg)


def _write_ft_vocab(outf: BinaryIO, vocab: Vocab):
    """
    Helper method to write a vocab to fastText.
    """
    # assumes that vocab_size == n_words when n_labels == 0
    _write_binary(outf, "<iii", len(vocab), len(vocab), 0)
    # n_tokens is not tracked, serialize as 0; no pruned vocabs exist, also 0
    _write_binary(outf, "<qq", 0, 0)
    for word in vocab:
        outf.write(word.encode("utf-8"))
        outf.write(b'\x00')
        # we don't store frequency, also set to 0
        _write_binary(outf, "<q", 0)
        # all entries are words = 0
        _write_binary(outf, "b", 0)


def _write_ft_storage_subwords(outf: BinaryIO, embeds: Embeddings):
    """
    Helper method to write a storage with subwords.

    Restores the original embedding format of fastText, i.e. precomputation is
    undone and the embeddings are un-normalized.
    """
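    # Inverse of the precomputation done at load time: there, each word vector became
    # v = (w + sum(ngrams)) / (n + 1) and was then l2-normalized, so the original
    # vector is recovered as w = v * norm * (n + 1) - sum(ngrams).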
    vocab = embeds.vocab
    assert isinstance(vocab, FastTextVocab)
    storage = embeds.storage
    norms = embeds.norms
    for i, word in enumerate(vocab):
        indices = vocab.subword_indices(word)
        embed = storage[i] * (len(indices) + 1)
        if norms is not None:
            embed *= norms[i]
        embed -= storage[indices].sum(0, keepdims=False)
        _serialize_array_as_le(outf, embed)

    _serialize_array_as_le(outf, storage[len(vocab):])


def _write_ft_storage_simple(outf: BinaryIO, embeds: Embeddings):
    """
    Helper method to write storage of a simple vocab model.

    Unnormalizes embeddings.
    """
    storage = embeds.storage
    norms = embeds.norms
    for i in range(storage.shape[0]):
        embed = storage[i]
        if norms is not None:
            embed = norms[i] * embed
        _serialize_array_as_le(outf, embed)


_FT_REQUIRED_CFG_KEYS = [
    'dims', 'window_size', 'epoch', 'min_count', 'ns', 'word_ngrams', 'loss',
    'model', 'buckets', 'min_n', 'max_n', 'lr_update_rate',
    'sampling_threshold'
]

__all__ = ['load_fasttext', 'write_fasttext']