
Commit 91776c0

Implement compat module for fastText.
Implement fastText reader and writer.
1 parent c1009a2 commit 91776c0

File tree

5 files changed, +374 -5 lines changed


src/finalfusion/compat/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -4,12 +4,13 @@
 This module contains read and write methods for other common embedding formats such as:
 * text(-dims)
 * word2vec binary
+* fastText
 """
-
+from finalfusion.compat.fasttext import write_fasttext, load_fasttext
 from finalfusion.compat.text import load_text, load_text_dims, write_text, write_text_dims
 from finalfusion.compat.word2vec import load_word2vec, write_word2vec

 __all__ = [
     'load_text_dims', 'load_word2vec', 'load_text', 'write_word2vec',
-    'write_text', 'write_text_dims'
+    'write_text', 'write_text_dims', 'load_fasttext', 'write_fasttext'
 ]
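As a quick illustration of the newly exported functions, a minimal usage sketch (not part of the commit; the file names are placeholders):

from finalfusion.compat import load_fasttext, write_fasttext

# Load a fastText binary model (placeholder path).
embeds = load_fasttext("model.bin")

# The storage holds the precomputed word rows followed by the bucket rows.
print(len(embeds.vocab), embeds.storage.shape)

# Round-trip the model back to the fastText binary format.
write_fasttext("model-roundtrip.bin", embeds)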

src/finalfusion/compat/fasttext.py

Lines changed: 342 additions & 0 deletions
@@ -0,0 +1,342 @@
"""
fastText IO compat module.
"""

import sys
from os import PathLike
from typing import Union, BinaryIO, cast, List, Any, Dict

import numpy as np

from finalfusion import Embeddings
from finalfusion._util import _normalize_ndarray_storage
from finalfusion.io import _read_required_binary, _write_binary, _serialize_array_as_le
from finalfusion.metadata import Metadata
from finalfusion.storage import NdArray
from finalfusion.subword import FastTextIndexer
from finalfusion.vocab import FastTextVocab, Vocab, SimpleVocab

_FT_MAGIC = 793_712_314


def load_fasttext(file: Union[str, bytes, int, PathLike]) -> Embeddings:
    """
    Read embeddings from a file in fastText format.

    The returned embeddings have a FastTextVocab, NdArray storage and a Norms chunk.

    Loading embeddings with this method will precompute embeddings for each word by averaging all
    of its subword embeddings together with the distinct word vector. Additionally, all precomputed
    vectors are l2-normalized and the corresponding norms are stored in the Norms. The subword
    embeddings are **not** l2-normalized.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Path to a file with embeddings in fastText binary format.

    Returns
    -------
    embeddings : Embeddings
        The embeddings from the input file.
    """
    with open(file, 'rb') as inf:
        _read_ft_header(inf)
        metadata = _read_ft_cfg(inf)
        vocab = _read_ft_vocab(inf, metadata['buckets'], metadata['min_n'],
                               metadata['max_n'])
        storage = _read_ft_storage(inf, vocab)
        norms = _normalize_ndarray_storage(storage[:len(vocab)])
    return Embeddings(storage=storage,
                      vocab=vocab,
                      norms=norms,
                      metadata=metadata)


def write_fasttext(file: Union[str, bytes, int, PathLike], embeds: Embeddings):
    """
    Write embeddings in fastText format.

    fastText requires Metadata with all expected keys for fastText configs:

    * dims: int (inferred from model)
    * window_size: int (default -1)
    * epoch: int (default -1)
    * min_count: int (default -1)
    * ns: int (default -1)
    * word_ngrams: int (default 1)
    * loss: one of ``['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']`` (default Softmax)
    * model: one of ``['CBOW', 'SkipGram', 'Supervised']`` (default SkipGram)
    * buckets: int (inferred from model)
    * min_n: int (inferred from model)
    * max_n: int (inferred from model)
    * lr_update_rate: int (default -1)
    * sampling_threshold: float (default -1)

    ``dims``, ``buckets``, ``min_n`` and ``max_n`` are inferred from the model. If other values
    are unspecified, a default value of ``-1`` is used for all numerical fields. Loss defaults
    to ``Softmax``, model to ``SkipGram``. Unknown values for ``loss`` and ``model`` are
    overwritten with the defaults since the models are incompatible with fastText otherwise.

    Some information from the original fastText model is lost, e.g.:

    * word frequencies
    * n_tokens

    Embeddings are un-normalized before serialization: if norms are present, each embedding is
    scaled by the associated norm. Additionally, the original state of the embedding matrix is
    restored; precomputation and l2-normalization of word embeddings are undone.

    Only embeddings with a FastTextVocab or SimpleVocab can be serialized to this format.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Output file
    embeds : Embeddings
        Embeddings to write
    """
    with open(file, 'wb') as outf:
        if not isinstance(embeds.vocab, (FastTextVocab, SimpleVocab)):
            raise ValueError(
                f'Expected FastTextVocab or SimpleVocab, not: {type(embeds.vocab).__name__}'
            )
        _write_binary(outf, "<ii", _FT_MAGIC, 12)
        _write_ft_cfg(outf, embeds)
        _write_ft_vocab(outf, embeds.vocab)
        _write_binary(outf, "<?QQ", 0, *embeds.storage.shape)
        if isinstance(embeds.vocab, SimpleVocab):
            _write_ft_storage_simple(outf, embeds)
        else:
            _write_ft_storage_subwords(outf, embeds)
        _serialize_array_as_le(outf, embeds.storage)


def _read_ft_header(file: BinaryIO):
    """
    Helper method to verify version and magic.
    """
    magic, version = _read_required_binary(file, "<ii")
    if magic != _FT_MAGIC:
        raise ValueError(f"Magic should be 793_712_314, not: {magic}")
    if version != 12:
        raise ValueError(f"Expected version 12, not: {version}")


def _read_ft_cfg(file: BinaryIO) -> Metadata:
    """
    Constructs metadata from the fastText config.
    """
    cfg = list(_read_required_binary(file, "<12id"))
    losses = ['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']
    cfg[6] = losses[cfg[6] - 1]
    models = ['CBOW', 'SkipGram', 'Supervised']
    cfg[7] = models[cfg[7] - 1]
    return Metadata(dict(zip(_FT_REQUIRED_CFG_KEYS, cfg)))


def _read_ft_vocab(file: BinaryIO, buckets: int, min_n: int,
                   max_n: int) -> Union[FastTextVocab, SimpleVocab]:
    """
    Helper method to read a vocab from a fastText file.

    Returns a SimpleVocab if min_n is 0, otherwise a FastTextVocab is returned.
    """
    # discard n_words
    vocab_size, _n_words, n_labels = _read_required_binary(file, "<iii")
    if n_labels:
        raise NotImplementedError(
            "fastText prediction models are not supported")
    # discard n_tokens
    _read_required_binary(file, "<q")

    prune_idx_size = _read_required_binary(file, "<q")[0]
    if prune_idx_size > 0:
        raise NotImplementedError("Pruned vocabs are not supported")

    if min_n:
        return _read_ft_subwordvocab(file, buckets, min_n, max_n, vocab_size)
    return SimpleVocab([_read_binary_word(file) for _ in range(vocab_size)])


def _read_ft_subwordvocab(file: BinaryIO, buckets: int, min_n: int, max_n: int,
                          vocab_size: int) -> FastTextVocab:
    """
    Helper method to build a FastTextVocab from a fastText file.
    """
    words = [_read_binary_word(file) for _ in range(vocab_size)]
    indexer = FastTextIndexer(buckets, min_n, max_n)
    return FastTextVocab(words, indexer)


def _read_binary_word(file: BinaryIO) -> str:
    """
    Helper method to read null-terminated binary strings.
    """
    word = bytearray()
    while True:
        byte = file.read(1)
        if byte == b'\x00':
            break
        if byte == b'':
            raise EOFError
        word.extend(byte)
    # discard frequency
    _ = _read_required_binary(file, "<q")
    entry_type = _read_required_binary(file, "b")[0]
    if entry_type != 0:
        raise ValueError(f'Non word entry: {word}')

    # pylint: disable=fixme # XXX handle unicode errors
    return word.decode("utf8")


def _read_ft_storage(file: BinaryIO, vocab: Vocab) -> NdArray:
    """
    Helper method to read fastText storage.

    If vocab is a SimpleVocab, the matrix is read and returned as is.
    If vocab is a FastTextVocab, the word representations are precomputed based
    on the vocab.
    """
    quantized = _read_required_binary(file, "?")[0]
    if quantized:
        raise NotImplementedError(
            "Quantized storage is not supported for fastText models")
    rows, cols = _read_required_binary(file, "<qq")
    matrix = np.fromfile(file=file, count=rows * cols,
                         dtype=np.float32).reshape((rows, cols))
    if sys.byteorder == 'big':
        matrix.byteswap(inplace=True)
    if isinstance(vocab, FastTextVocab):
        _precompute_word_vecs(vocab, matrix)
    return NdArray(matrix)


def _precompute_word_vecs(vocab: FastTextVocab, matrix: np.ndarray):
    """
    Helper method to precompute word vectors.

    Averages the distinct word representation and the corresponding ngram
    embeddings.
    """
    for i, word in enumerate(vocab):
        indices = [i]
        if isinstance(vocab, FastTextVocab):
            subword_indices = cast(
                List[int], vocab.subword_indices(word, with_ngrams=False))
            indices += subword_indices
        matrix[i] = matrix[indices].mean(0, keepdims=False)


def _write_ft_cfg(file: BinaryIO, embeds: Embeddings):
    """
    Helper method to write the fastText config.

    * dims: taken from embeds
    * window_size: -1 if unspecified
    * min_count: -1 if unspecified
    * ns: -1 if unspecified
    * word_ngrams: 1
    * loss: one of `['HierarchicalSoftmax', 'NegativeSampling', 'Softmax']`, defaults to 'Softmax'
    * model: one of `['CBOW', 'SkipGram', 'Supervised']`, defaults to SkipGram
    * buckets: taken from embeds, 0 if SimpleVocab
    * min_n: taken from embeds, 0 if SimpleVocab
    * max_n: taken from embeds, 0 if SimpleVocab
    * lr_update_rate: -1 if unspecified
    * sampling_threshold: -1 if unspecified

    loss and model values are overwritten by the default if they are not listed above.
    """
    # declare some dummy values that we can't get from embeds
    meta = {
        'window_size': -1,
        'epoch': -1,
        'min_count': -1,
        'ns': -1,
        'word_ngrams': 1,
        'loss': 'Softmax',
        # fastText uses an integral enum with vals 1, 2, 3, so we can't use
        # a placeholder for unknown models which maps to e.g. 0.
        'model': 'SkipGram',
        'lr_update_rate': -1,
        'sampling_threshold': -1
    }  # type: Dict[str, Any]
    if embeds.metadata is not None:
        meta.update(embeds.metadata)
    meta['dims'] = embeds.storage.shape[1]
    if isinstance(embeds.vocab, FastTextVocab):
        meta['min_n'] = embeds.vocab.min_n
        meta['max_n'] = embeds.vocab.max_n
        meta['buckets'] = embeds.vocab.subword_indexer.n_buckets
    else:
        meta['min_n'] = 0
        meta['max_n'] = 0
        meta['buckets'] = 0
    cfg = [meta[k] for k in _FT_REQUIRED_CFG_KEYS]
    # see explanation above why we need to select some known value
    losses = {'HierarchicalSoftmax': 1, 'NegativeSampling': 2, 'Softmax': 3}
    cfg[6] = losses.get(cfg[6], 3)
    models = {'CBOW': 1, 'SkipGram': 2, 'Supervised': 3}
    cfg[7] = models.get(cfg[7], 2)
    _write_binary(file, "<12id", *cfg)


def _write_ft_vocab(outf: BinaryIO, vocab: Vocab):
    """
    Helper method to write a vocab in fastText format.
    """
    # assumes that vocab_size == word_size if n_labels == 0
    _write_binary(outf, "<iii", len(vocab), len(vocab), 0)
    # we discard n_tokens, serialize as 0; no pruned vocabs exist, also 0
    _write_binary(outf, "<qq", 0, 0)
    for word in vocab:
        outf.write(word.encode("utf-8"))
        outf.write(b'\x00')
        # we don't store frequency, also set to 0
        _write_binary(outf, "<q", 0)
        # all entries are words = 0
        _write_binary(outf, "b", 0)


def _write_ft_storage_subwords(outf: BinaryIO, embeds: Embeddings):
    """
    Helper method to write a storage with subwords.

    Restores the original embedding layout of fastText, i.e. precomputation is
    undone and the embeddings are unnormalized.
    """
    vocab = embeds.vocab
    assert isinstance(vocab, FastTextVocab)
    storage = embeds.storage
    norms = embeds.norms
    for i, word in enumerate(vocab):
        indices = vocab.subword_indices(word)
        embed = storage[i] * (len(indices) + 1)
        if norms is not None:
            embed *= norms[i]
        embed -= storage[indices].sum(0, keepdims=False)
        _serialize_array_as_le(outf, embed)

    _serialize_array_as_le(outf, storage[len(vocab):])


def _write_ft_storage_simple(outf: BinaryIO, embeds: Embeddings):
    """
    Helper method to write the storage of a simple vocab model.

    Unnormalizes embeddings.
    """
    storage = embeds.storage
    norms = embeds.norms
    for i in range(storage.shape[0]):
        embed = storage[i]
        if norms is not None:
            embed = norms[i] * embed
        _serialize_array_as_le(outf, embed)


_FT_REQUIRED_CFG_KEYS = [
    'dims', 'window_size', 'epoch', 'min_count', 'ns', 'word_ngrams', 'loss',
    'model', 'buckets', 'min_n', 'max_n', 'lr_update_rate',
    'sampling_threshold'
]

__all__ = ['load_fasttext', 'write_fasttext']
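Why the un-precomputation in _write_ft_storage_subwords exactly undoes _precompute_word_vecs can be checked with a small numpy sketch; this is an illustration with made-up values, not part of the commit:

import numpy as np

rng = np.random.default_rng(42)
dims, n_subwords = 5, 3

word_row = rng.normal(size=dims).astype(np.float32)                   # distinct word vector
subword_rows = rng.normal(size=(n_subwords, dims)).astype(np.float32)

# Load path: average the word row with its subword rows (_precompute_word_vecs),
# then l2-normalize the result and remember the norm (load_fasttext).
precomputed = np.vstack([word_row[None, :], subword_rows]).mean(0)
norm = np.linalg.norm(precomputed)
stored = precomputed / norm

# Write path (_write_ft_storage_subwords): scale by (n_subwords + 1) and by the norm,
# then subtract the subword rows to recover the original distinct word vector.
restored = stored * (n_subwords + 1) * norm - subword_rows.sum(0)

assert np.allclose(restored, word_row, atol=1e-5)

The inversion works because precomputation only overwrites the word rows of the matrix; the bucket rows subtracted here are still the original subword vectors.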

tests/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -59,3 +59,9 @@ def embeddings_text_dims(tests_root):
 def embeddings_w2v(tests_root):
     yield finalfusion.compat.load_word2vec(
         os.path.join(tests_root, "data/embeddings.w2v"))
+
+
+@pytest.fixture
+def embeddings_ft(tests_root):
+    yield finalfusion.compat.load_fasttext(
+        os.path.join(tests_root, "data/fasttext.bin"))

tests/data/fasttext.bin

83.1 KB (binary file not shown)
