Skip to content

Commit b723797

Browse files
committed
Implement FastTextVocab.
1 parent eec2213 commit b723797

File tree

4 files changed

+139
-14
lines changed

4 files changed

+139
-14
lines changed

src/finalfusion/embeddings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from finalfusion.metadata import Metadata
1212
from finalfusion.norms import Norms
1313
from finalfusion.storage import Storage, NdArray
14-
from finalfusion.vocab import Vocab, SimpleVocab, FinalfusionBucketVocab
14+
from finalfusion.vocab import Vocab, SimpleVocab, FinalfusionBucketVocab, FastTextVocab
1515

1616

1717
class Embeddings: # pylint: disable=too-many-instance-attributes
@@ -38,7 +38,8 @@ class Embeddings: # pylint: disable=too-many-instance-attributes
3838
* :class:`~finalfusion.storage.ndarray.NdArray`
3939
2. :class:`~finalfusion.vocab.Vocab` *(required)*:
4040
* :class:`~finalfusion.vocab.simple_vocab.SimpleVocab`,
41-
:class:`~finalfusion.vocab.subword.FinalfusionBucketVocab`
41+
* :class:`~finalfusion.vocab.subword.FinalfusionBucketVocab`
42+
* :class:`~finalfusion.vocab.subword.FastTextVocab`
4243
3. :class:`~finalfusion.metadata.Metadata`
4344
4. :class:`~finalfusion.norms.Norms`
4445
@@ -460,6 +461,8 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
460461
vocab = SimpleVocab.read_chunk(inf) # type: Vocab
461462
elif chunk_id == ChunkIdentifier.BucketSubwordVocab:
462463
vocab = FinalfusionBucketVocab.read_chunk(inf)
464+
elif chunk_id == ChunkIdentifier.FastTextSubwordVocab:
465+
vocab = FastTextVocab.read_chunk(inf)
463466
else:
464467
raise FinalfusionFormatError(
465468
f'Expected vocab chunk, not {str(chunk_id)}')

src/finalfusion/vocab/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from finalfusion.io import ChunkIdentifier, find_chunk
88
from finalfusion.vocab.simple_vocab import SimpleVocab, load_simple_vocab
9-
from finalfusion.vocab.subword import FinalfusionBucketVocab, load_finalfusion_bucket_vocab
9+
from finalfusion.vocab.subword import FinalfusionBucketVocab, load_finalfusion_bucket_vocab, \
10+
FastTextVocab, load_fasttext_vocab
1011
from finalfusion.vocab.vocab import Vocab
1112

1213

@@ -43,10 +44,18 @@ def load_vocab(file: Union[str, bytes, int, PathLike]) -> Vocab:
4344
return SimpleVocab.read_chunk(inf)
4445
if chunk == ChunkIdentifier.BucketSubwordVocab:
4546
return FinalfusionBucketVocab.read_chunk(inf)
47+
if chunk == ChunkIdentifier.FastTextSubwordVocab:
48+
return FastTextVocab.read_chunk(inf)
4649
raise NotImplementedError('Vocab type is not yet supported.')
4750

4851

4952
__all__ = [
50-
'Vocab', 'load_vocab', 'SimpleVocab', 'load_simple_vocab',
51-
'FinalfusionBucketVocab', 'load_finalfusion_bucket_vocab'
53+
'Vocab',
54+
'load_vocab',
55+
'SimpleVocab',
56+
'load_simple_vocab',
57+
'FinalfusionBucketVocab',
58+
'load_finalfusion_bucket_vocab',
59+
'FastTextVocab',
60+
'load_fasttext_vocab',
5261
]

src/finalfusion/vocab/subword.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,16 @@ class FinalfusionBucketVocab(SubwordVocab):
148148
"""
149149
def __init__(self,
150150
words: List[str],
151-
indexer: FinalfusionHashIndexer = None):
151+
indexer: Optional[FinalfusionHashIndexer] = None):
152152
"""
153153
Initialize a FinalfusionBucketVocab.
154154
155-
Initializes the vocabulary with the given words and optional index and
156-
indexer.
155+
Initializes the vocabulary with the given words.
157156
158157
If no indexer is passed, a FinalfusionHashIndexer with bucket exponent
159158
21 is used.
160159
161-
If no index is given, the nth word in the `words` list is assigned
162-
index `n`. The word list cannot contain duplicate entries and it needs
163-
to be of same length as the index.
160+
The word list cannot contain duplicate entries.
164161
165162
Parameters
166163
----------
@@ -211,6 +208,70 @@ def chunk_identifier() -> ChunkIdentifier:
211208
return ChunkIdentifier.BucketSubwordVocab
212209

213210

211+
class FastTextVocab(SubwordVocab):
212+
"""
213+
FastText vocabulary
214+
"""
215+
def __init__(self,
216+
words: List[str],
217+
indexer: Optional[FastTextIndexer] = None):
218+
"""
219+
Initialize a FastTextVocab.
220+
221+
Initializes the vocabulary with the given words.
222+
223+
If no indexer is passed, a FastTextIndexer with 2_000_000 buckets is used.
224+
225+
The word list cannot contain duplicate entries.
226+
227+
Parameters
228+
----------
229+
words : List[str]
230+
List of unique words
231+
indexer : FastTextIndexer, optional
232+
Subword indexer to use for the vocabulary. Defaults to an indexer
233+
with 2_000_000 buckets and range 3-6.
234+
235+
Raises
236+
------
237+
AssertionError
238+
If the indexer is not a FastTextIndexer or ``words`` contains duplicate entries.
239+
"""
240+
if indexer is None:
241+
indexer = FastTextIndexer(2000000)
242+
assert isinstance(indexer, FastTextIndexer)
243+
super().__init__()
244+
self._index = _validate_items_and_create_index(words)
245+
self._words = words
246+
self._indexer = indexer
247+
248+
@property
249+
def subword_indexer(self) -> FastTextIndexer:
250+
return self._indexer
251+
252+
@property
253+
def words(self) -> List[str]:
254+
return self._words
255+
256+
@property
257+
def word_index(self) -> Dict[str, int]:
258+
return self._index
259+
260+
@staticmethod
261+
def read_chunk(file: BinaryIO) -> 'FastTextVocab':
262+
length, min_n, max_n, buckets = _read_required_binary(file, "<QIII")
263+
words = _read_items(file, length)
264+
indexer = FastTextIndexer(buckets, min_n, max_n)
265+
return FastTextVocab(words, indexer)
266+
267+
def write_chunk(self, file: BinaryIO):
268+
_write_bucket_vocab(file, self)
269+
270+
@staticmethod
271+
def chunk_identifier():
272+
return ChunkIdentifier.FastTextSubwordVocab
273+
274+
214275
def load_finalfusion_bucket_vocab(file: Union[str, bytes, int, PathLike]
215276
) -> FinalfusionBucketVocab:
216277
"""
@@ -233,7 +294,30 @@ def load_finalfusion_bucket_vocab(file: Union[str, bytes, int, PathLike]
233294
return FinalfusionBucketVocab.read_chunk(inf)
234295

235296

236-
def _write_bucket_vocab(file: BinaryIO, vocab: FinalfusionBucketVocab):
297+
def load_fasttext_vocab(file: Union[str, bytes, int, PathLike]
298+
) -> FastTextVocab:
299+
"""
300+
Load a FastTextVocab from the given finalfusion file.
301+
302+
Parameters
303+
----------
304+
file : str, bytes, int, PathLike
305+
Path to file containing a FastTextVocab chunk.
306+
307+
Returns
308+
-------
309+
vocab : FastTextVocab
310+
Returns the first FastTextVocab in the file.
311+
"""
312+
with open(file, "rb") as inf:
313+
chunk = find_chunk(inf, [ChunkIdentifier.FastTextSubwordVocab])
314+
if chunk is None:
315+
raise ValueError('File did not contain a FastTextVocab}')
316+
return FastTextVocab.read_chunk(inf)
317+
318+
319+
def _write_bucket_vocab(file: BinaryIO,
320+
vocab: Union[FastTextVocab, FinalfusionBucketVocab]):
237321
min_n_max_n_size = struct.calcsize("<II")
238322
buckets_size = struct.calcsize("<I")
239323
chunk_length = _calculate_binary_list_size(vocab.words)
@@ -254,5 +338,6 @@ def _write_bucket_vocab(file: BinaryIO, vocab: FinalfusionBucketVocab):
254338

255339

256340
__all__ = [
257-
'SubwordVocab', 'FinalfusionBucketVocab', 'load_finalfusion_bucket_vocab'
341+
'SubwordVocab', 'FinalfusionBucketVocab', 'load_finalfusion_bucket_vocab',
342+
'FastTextVocab', 'load_fasttext_vocab'
258343
]

tests/test_vocab.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from finalfusion.io import FinalfusionFormatError
55
from finalfusion.subword import FinalfusionHashIndexer, FastTextIndexer
6-
from finalfusion.vocab import FinalfusionBucketVocab, SimpleVocab, load_vocab
6+
from finalfusion.vocab import FinalfusionBucketVocab, SimpleVocab, load_vocab, FastTextVocab
77

88

99
def test_reading(tests_root):
@@ -90,6 +90,34 @@ def test_fifu_buckets_constructor():
9090
")"
9191

9292

93+
def test_fasttext_constructor():
94+
v = FastTextVocab([str(i) for i in range(10)])
95+
assert [v[str(i)] for i in range(10)] == [i for i in range(10)]
96+
with pytest.raises(AssertionError):
97+
v = FastTextVocab(["a"] * 2)
98+
with pytest.raises(AssertionError):
99+
_ = FastTextVocab(v.words, FinalfusionHashIndexer(21))
100+
assert len(v) == 10
101+
assert v.upper_bound == len(v) + 2_000_000
102+
assert v == v
103+
assert v in v
104+
assert v != SimpleVocab(v.words)
105+
assert v != FastTextVocab(v.words, FastTextIndexer(20))
106+
assert repr(v) == f"FastTextVocab(\n" \
107+
f"\tindexer={repr(v.subword_indexer)}\n" \
108+
"\twords=[...]\n" \
109+
"\tword_index={{...}}\n" \
110+
")"
111+
112+
113+
def test_fasttext_vocab_roundtrip(tmp_path):
114+
filename = tmp_path / "write_ft_vocab.fifu"
115+
v = FastTextVocab([str(i) for i in range(10)])
116+
v.write(filename)
117+
v2 = load_vocab(filename)
118+
assert v == v2
119+
120+
93121
def test_fifu_buckets_roundtrip(tests_root, tmp_path):
94122
filename = tmp_path / "write_ff_buckets.fifu"
95123
v = load_vocab(tests_root / "data" / "ff_buckets.fifu")

0 commit comments

Comments
 (0)