Skip to content

Commit ea1992f

Browse files
committed
Implement ExplicitVocab.
1 parent b723797 commit ea1992f

File tree

7 files changed

+184
-26
lines changed

7 files changed

+184
-26
lines changed

src/finalfusion/embeddings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from finalfusion.metadata import Metadata
1212
from finalfusion.norms import Norms
1313
from finalfusion.storage import Storage, NdArray
14-
from finalfusion.vocab import Vocab, SimpleVocab, FinalfusionBucketVocab, FastTextVocab
14+
from finalfusion.vocab import Vocab, SimpleVocab, FinalfusionBucketVocab, FastTextVocab, \
15+
ExplicitVocab
1516

1617

1718
class Embeddings: # pylint: disable=too-many-instance-attributes
@@ -40,6 +41,7 @@ class Embeddings: # pylint: disable=too-many-instance-attributes
4041
* :class:`~finalfusion.vocab.simple_vocab.SimpleVocab`,
4142
* :class:`~finalfusion.vocab.subword.FinalfusionBucketVocab`
4243
* :class:`~finalfusion.vocab.subword.FastTextVocab`
44+
* :class:`~finalfusion.vocab.subword.ExplicitVocab`
4345
3. :class:`~finalfusion.metadata.Metadata`
4446
4. :class:`~finalfusion.norms.Norms`
4547
@@ -463,6 +465,8 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
463465
vocab = FinalfusionBucketVocab.read_chunk(inf)
464466
elif chunk_id == ChunkIdentifier.FastTextSubwordVocab:
465467
vocab = FastTextVocab.read_chunk(inf)
468+
elif chunk_id == ChunkIdentifier.ExplicitSubwordVocab:
469+
vocab = ExplicitVocab.read_chunk(inf)
466470
else:
467471
raise FinalfusionFormatError(
468472
f'Expected vocab chunk, not {str(chunk_id)}')

src/finalfusion/subword/explicit_indexer.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ cdef class ExplicitIndexer:
207207
f"\tmin_n={self.min_n},\n" \
208208
f"\tmax_n={self.max_n},\n" \
209209
"\tngrams=[...],\n" \
210-
"\tngram_index={{...}}\n" \
211-
")"
210+
"\tngram_index={{...}})"
212211

213212
__all__ = ['ExplicitIndexer']

src/finalfusion/vocab/__init__.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from finalfusion.io import ChunkIdentifier, find_chunk
88
from finalfusion.vocab.simple_vocab import SimpleVocab, load_simple_vocab
99
from finalfusion.vocab.subword import FinalfusionBucketVocab, load_finalfusion_bucket_vocab, \
10-
FastTextVocab, load_fasttext_vocab
10+
FastTextVocab, load_fasttext_vocab, ExplicitVocab, load_explicit_vocab
1111
from finalfusion.vocab.vocab import Vocab
1212

1313

@@ -17,14 +17,20 @@ def load_vocab(file: Union[str, bytes, int, PathLike]) -> Vocab:
1717
1818
Loads the first known vocabulary from a finalfusion file.
1919
20+
One of:
21+
* :class:`~finalfusion.vocab.simple_vocab.SimpleVocab`,
22+
* :class:`~finalfusion.vocab.subword.FinalfusionBucketVocab`
23+
* :class:`~finalfusion.vocab.subword.FastTextVocab`
24+
* :class:`~finalfusion.vocab.subword.ExplicitVocab`
25+
2026
Parameters
2127
----------
2228
file: str, bytes, int, PathLike
2329
Path to file containing a finalfusion vocab chunk.
2430
2531
Returns
2632
-------
27-
vocab : Union[SimpleVocab, FinalfusionBucketVocab]
33+
vocab : Vocab
2834
First vocabulary in the file.
2935
3036
Raises
@@ -46,16 +52,13 @@ def load_vocab(file: Union[str, bytes, int, PathLike]) -> Vocab:
4652
return FinalfusionBucketVocab.read_chunk(inf)
4753
if chunk == ChunkIdentifier.FastTextSubwordVocab:
4854
return FastTextVocab.read_chunk(inf)
49-
raise NotImplementedError('Vocab type is not yet supported.')
55+
if chunk == ChunkIdentifier.ExplicitSubwordVocab:
56+
return ExplicitVocab.read_chunk(inf)
57+
raise ValueError(f'Unexpected chunk type {chunk}.')
5058

5159

5260
__all__ = [
53-
'Vocab',
54-
'load_vocab',
55-
'SimpleVocab',
56-
'load_simple_vocab',
57-
'FinalfusionBucketVocab',
58-
'load_finalfusion_bucket_vocab',
59-
'FastTextVocab',
60-
'load_fasttext_vocab',
61+
'Vocab', 'load_vocab', 'SimpleVocab', 'load_simple_vocab',
62+
'FinalfusionBucketVocab', 'load_finalfusion_bucket_vocab', 'FastTextVocab',
63+
'load_fasttext_vocab', 'ExplicitVocab', 'load_explicit_vocab'
6164
]

src/finalfusion/vocab/subword.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from finalfusion.io import ChunkIdentifier, find_chunk, _write_binary, _read_required_binary
1111
from finalfusion.subword import ExplicitIndexer, FastTextIndexer, FinalfusionHashIndexer, ngrams
1212
from finalfusion.vocab.vocab import Vocab, _validate_items_and_create_index, \
13-
_calculate_binary_list_size, _write_words_binary, _read_items
13+
_calculate_binary_list_size, _write_words_binary, _read_items, _read_items_with_indices
1414

1515

1616
class SubwordVocab(Vocab):
@@ -133,8 +133,7 @@ def __repr__(self) -> str:
133133
return f"{type(self).__name__}(\n" \
134134
f"\tindexer={self.subword_indexer}\n" \
135135
"\twords=[...]\n" \
136-
"\tword_index={{...}}\n" \
137-
")"
136+
"\tword_index={{...}})"
138137

139138
def __eq__(self, other: Any) -> bool:
140139
return isinstance(other, type(self)) and \
@@ -272,6 +271,84 @@ def chunk_identifier():
272271
return ChunkIdentifier.FastTextSubwordVocab
273272

274273

274+
class ExplicitVocab(SubwordVocab):
    """
    A vocabulary with explicitly stored n-grams.

    Words are kept in a list with a matching word -> index mapping; subword
    lookups are delegated to an :class:`.ExplicitIndexer`.
    """
    def __init__(self, words: List[str], indexer: ExplicitIndexer):
        """
        Initialize an ExplicitVocab.

        Initializes the vocabulary with the given words and ExplicitIndexer.

        The word list cannot contain duplicate entries.

        Parameters
        ----------
        words : List[str]
            List of unique words
        indexer : ExplicitIndexer
            Subword indexer to use for the vocabulary.

        Raises
        ------
        AssertionError
            If the indexer is not an ExplicitIndexer.

        See Also
        --------
        :class:`.ExplicitIndexer`
        """
        # Kept as an assert on purpose: AssertionError is the documented
        # failure mode that callers and tests rely on.
        assert isinstance(indexer, ExplicitIndexer)
        super().__init__()
        self._index = _validate_items_and_create_index(words)
        self._words = words
        self._indexer = indexer

    @property
    def word_index(self) -> dict:
        """Mapping from word to its index in :attr:`words`."""
        return self._index

    @property
    def subword_indexer(self) -> ExplicitIndexer:
        """The ExplicitIndexer this vocabulary was constructed with."""
        return self._indexer

    @property
    def words(self) -> list:
        """The list of words this vocabulary was constructed with."""
        return self._words

    @staticmethod
    def chunk_identifier():
        """Return the chunk identifier used for serialized ExplicitVocabs."""
        return ChunkIdentifier.ExplicitSubwordVocab

    @staticmethod
    def read_chunk(file: BinaryIO) -> 'ExplicitVocab':
        """
        Read an ExplicitVocab from ``file``.

        Reads the ``<QQII`` header (number of words, number of n-grams,
        ``min_n``, ``max_n``), then the words, then the n-grams together
        with their stored indices.

        Parameters
        ----------
        file : BinaryIO
            Input file, positioned directly before the word count.

        Returns
        -------
        vocab : ExplicitVocab
            The vocabulary built from the chunk contents.
        """
        length, ngram_length, min_n, max_n = _read_required_binary(
            file, "<QQII")
        words = _read_items(file, length)
        ngram_list, ngram_index = _read_items_with_indices(file, ngram_length)
        indexer = ExplicitIndexer(ngram_list, min_n, max_n, ngram_index)
        return ExplicitVocab(words, indexer)

    def write_chunk(self, file) -> None:
        """
        Write this vocabulary to ``file`` as an explicit subword vocab chunk.

        The ``<IQQQII`` header holds the chunk identifier, the chunk length,
        the number of words, the number of n-grams, ``min_n`` and ``max_n``.
        It is followed by the words and then, for every n-gram, its
        ``<I`` byte length, its UTF-8 bytes and its ``<Q`` index.

        Parameters
        ----------
        file
            Binary output file.
        """
        ngrams = self.subword_indexer.ngrams
        ngram_index = self.subword_indexer.ngram_index
        chunk_length = _calculate_binary_list_size(self.words)
        chunk_length += _calculate_binary_list_size(ngrams)
        # Every n-gram is followed by its "<Q" index in the loop below; these
        # bytes have to be part of the declared chunk length, otherwise a
        # reader skipping the chunk by its length ends up inside of it.
        # NOTE(review): assumes _calculate_binary_list_size only accounts for
        # the count prefix, the length prefixes and the string bytes -- confirm.
        chunk_length += len(ngrams) * struct.calcsize("<Q")
        chunk_length += struct.calcsize("<II")
        chunk_header = (int(self.chunk_identifier()), chunk_length,
                        len(self.words), len(ngrams),
                        self.min_n, self.max_n)
        _write_binary(file, "<IQQQII", *chunk_header)
        _write_words_binary((bytes(word, "utf-8") for word in self.words),
                            file)
        for ngram in ngrams:
            b_ngram = ngram.encode("utf-8")
            _write_binary(file, "<I", len(b_ngram))
            file.write(b_ngram)
            _write_binary(file, "<Q", ngram_index[ngram])
350+
351+
275352
def load_finalfusion_bucket_vocab(file: Union[str, bytes, int, PathLike]
276353
) -> FinalfusionBucketVocab:
277354
"""
@@ -316,6 +393,28 @@ def load_fasttext_vocab(file: Union[str, bytes, int, PathLike]
316393
return FastTextVocab.read_chunk(inf)
317394

318395

396+
def load_explicit_vocab(file: Union[str, bytes, int, PathLike]
                        ) -> ExplicitVocab:
    """
    Load an ExplicitVocab from the given finalfusion file.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Path to file containing an ExplicitVocab chunk.

    Returns
    -------
    vocab : ExplicitVocab
        Returns the first ExplicitVocab in the file.

    Raises
    ------
    ValueError
        If the file does not contain an ExplicitVocab chunk.
    """
    with open(file, "rb") as inf:
        chunk = find_chunk(inf, [ChunkIdentifier.ExplicitSubwordVocab])
        if chunk is None:
            # The message used to name FastTextVocab (copy-paste slip) and
            # carried a stray closing brace.
            raise ValueError('File did not contain an ExplicitVocab')
        return ExplicitVocab.read_chunk(inf)
416+
417+
319418
def _write_bucket_vocab(file: BinaryIO,
320419
vocab: Union[FastTextVocab, FinalfusionBucketVocab]):
321420
min_n_max_n_size = struct.calcsize("<II")
@@ -339,5 +438,6 @@ def _write_bucket_vocab(file: BinaryIO,
339438

340439
__all__ = [
341440
'SubwordVocab', 'FinalfusionBucketVocab', 'load_finalfusion_bucket_vocab',
342-
'FastTextVocab', 'load_fasttext_vocab'
441+
'FastTextVocab', 'load_fasttext_vocab', 'ExplicitVocab',
442+
'load_explicit_vocab'
343443
]

src/finalfusion/vocab/vocab.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,34 @@ def _read_items(file: BinaryIO, length: int) -> List[str]:
131131
return items
132132

133133

134+
def _read_items_with_indices(file: BinaryIO,
                             length: int) -> Tuple[List[str], Dict[str, int]]:
    """
    Helper method to read indexed items from a vocabulary chunk.

    Every entry is read as a ``<I`` byte length, that many bytes decoded as
    UTF-8, and a ``<Q`` index belonging to the decoded item.

    Parameters
    ----------
    file : BinaryIO
        input file
    length : int
        number of items to read

    Returns
    -------
    (items, index) : Tuple[List[str], Dict[str, int]]
        The items in the order they were read and a mapping from each item
        to the index stored next to it.
    """
    entries = []
    lookup = {}
    for _ in range(length):
        n_bytes = _read_required_binary(file, "<I")[0]
        entry = file.read(n_bytes).decode("utf-8")
        # The stored index directly follows the item's bytes.
        lookup[entry] = _read_required_binary(file, "<Q")[0]
        entries.append(entry)
    return entries, lookup
160+
161+
134162
def _calculate_binary_list_size(items: List[str]):
135163
size = sum(len(bytes(item, "utf-8")) for item in items)
136164
size += struct.calcsize("<Q")

tests/test_subwords.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def test_explicit():
8181
"\tmin_n=3,\n" \
8282
"\tmax_n=6,\n" \
8383
"\tngrams=[...],\n" \
84-
"\tngram_index={{...}}\n" \
85-
")"
84+
"\tngram_index={{...}})"
8685
assert indexer["0"] == 0
8786
assert indexer.ngrams[0] == "0"
8887
assert indexer("0") == 0

tests/test_vocab.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import finalfusion.vocab
33

44
from finalfusion.io import FinalfusionFormatError
5-
from finalfusion.subword import FinalfusionHashIndexer, FastTextIndexer
6-
from finalfusion.vocab import FinalfusionBucketVocab, SimpleVocab, load_vocab, FastTextVocab
5+
from finalfusion.subword import FinalfusionHashIndexer, FastTextIndexer, ExplicitIndexer
6+
from finalfusion.vocab import FinalfusionBucketVocab, SimpleVocab, load_vocab, FastTextVocab, ExplicitVocab
77

88

99
def test_reading(tests_root):
@@ -86,8 +86,7 @@ def test_fifu_buckets_constructor():
8686
assert repr(v) == f"FinalfusionBucketVocab(\n" \
8787
f"\tindexer={repr(v.subword_indexer)}\n" \
8888
"\twords=[...]\n" \
89-
"\tword_index={{...}}\n" \
90-
")"
89+
"\tword_index={{...}})"
9190

9291

9392
def test_fasttext_constructor():
@@ -106,8 +105,7 @@ def test_fasttext_constructor():
106105
assert repr(v) == f"FastTextVocab(\n" \
107106
f"\tindexer={repr(v.subword_indexer)}\n" \
108107
"\twords=[...]\n" \
109-
"\tword_index={{...}}\n" \
110-
")"
108+
"\tword_index={{...}})"
111109

112110

113111
def test_fasttext_vocab_roundtrip(tmp_path):
@@ -118,6 +116,33 @@ def test_fasttext_vocab_roundtrip(tmp_path):
118116
assert v == v2
119117

120118

119+
def test_explicit_constructor():
    """
    Check construction, lookup, equality and repr of an ExplicitVocab.

    The vocabulary holds the words "10".."99" and uses an ExplicitIndexer
    built from the n-grams "0".."9".
    """
    indexer = ExplicitIndexer([str(n) for n in range(10)])
    words = [str(n) for n in range(10, 100)]
    vocab = ExplicitVocab(words, indexer=indexer)
    # Words are indexed in insertion order.
    assert [vocab[str(n)] for n in range(10, 100)] == list(range(90))
    # Any indexer other than an ExplicitIndexer is rejected.
    with pytest.raises(AssertionError):
        _ = ExplicitVocab(vocab.words, FinalfusionHashIndexer(21))
    assert len(vocab) == 90
    assert vocab.upper_bound == len(vocab) + 10
    assert vocab == vocab
    assert vocab in vocab
    assert vocab != SimpleVocab(vocab.words)
    assert vocab != FastTextVocab(vocab.words, FastTextIndexer(20))
    expected_repr = f"ExplicitVocab(\n" \
                    f"\tindexer={repr(vocab.subword_indexer)}\n" \
                    "\twords=[...]\n" \
                    "\tword_index={{...}})"
    assert repr(vocab) == expected_repr
135+
136+
137+
def test_explicit_vocab_roundtrip(tmp_path):
    """
    Write an ExplicitVocab to disk and check that loading it gives an equal vocab.
    """
    indexer = ExplicitIndexer([str(n) for n in range(10)])
    original = ExplicitVocab([str(n) for n in range(10, 100)], indexer=indexer)
    target = tmp_path / "write_explicit_vocab.fifu"
    original.write(target)
    restored = load_vocab(target)
    assert original == restored
144+
145+
121146
def test_fifu_buckets_roundtrip(tests_root, tmp_path):
122147
filename = tmp_path / "write_ff_buckets.fifu"
123148
v = load_vocab(tests_root / "data" / "ff_buckets.fifu")

0 commit comments

Comments
 (0)