Commit 006cf92

Bucket to explicit conversion.
Implement conversion of bucket vocabularies to explicit ones and conversion of embeddings with bucket to explicit vocabs.
1 parent ea1992f

File tree

4 files changed: +147 −0 lines changed

src/finalfusion/embeddings.py
src/finalfusion/vocab/subword.py
tests/test_embeddings.py
tests/test_vocab.py

src/finalfusion/embeddings.py

Lines changed: 37 additions & 0 deletions
@@ -381,6 +381,43 @@ def write(self, file: str):
             for chunk in chunks:
                 chunk.write_chunk(outf)
 
+    def bucket_to_explicit(self) -> 'Embeddings':
+        """
+        Bucket to explicit Embeddings conversion.
+
+        Multiple embeddings can still map to the same bucket, but all buckets
+        that are not indexed by in-vocabulary n-grams are eliminated. This can
+        have a big impact on the size of the embedding matrix.
+
+        Metadata is **not** copied to the new embeddings since it doesn't
+        reflect the changes. You can manually set the metadata and update the
+        values accordingly.
+
+        Returns
+        -------
+        embeddings : Embeddings
+            Embeddings with an ExplicitVocab instead of a hash-based vocabulary.
+
+        Raises
+        ------
+        TypeError
+            If the current vocabulary is not a hash-based vocabulary
+            (FinalfusionBucketVocab or FastTextVocab).
+        """
+        bucket_vocabs = (FastTextVocab, FinalfusionBucketVocab)
+        if not isinstance(self.vocab, bucket_vocabs):
+            raise TypeError(
+                "Only bucketed embeddings can be converted to explicit.")
+        vocab = self.vocab.to_explicit()
+        storage = np.zeros((vocab.upper_bound, self._storage.shape[1]),
+                           dtype=np.float32)
+        storage[:len(vocab)] = self._storage[:len(vocab)]
+        for ngram in vocab.subword_indexer:
+            storage[len(vocab) + vocab.subword_indexer[ngram]] = self._storage[
+                len(vocab) + self.vocab.subword_indexer(ngram)]
+        return Embeddings(vocab=vocab,
+                          storage=NdArray(storage),
+                          norms=self.norms)
+
     def __contains__(self, item):
         return item in self._vocab
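
The method above composes with the existing load/write API. A minimal usage sketch (not part of the commit; the file paths are hypothetical, and the import assumes load_finalfusion is exposed at the package root, as in the tests below):

    from finalfusion import load_finalfusion

    # Must contain a FinalfusionBucketVocab or FastTextVocab,
    # otherwise bucket_to_explicit raises TypeError.
    embeds = load_finalfusion("bucketed.fifu")   # hypothetical path
    explicit = embeds.bucket_to_explicit()
    explicit.write("explicit.fifu")              # hypothetical path

Only buckets reachable from in-vocabulary ngrams survive the conversion, so the written file is typically much smaller than the bucketed input.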

src/finalfusion/vocab/subword.py

Lines changed: 57 additions & 0 deletions
@@ -180,6 +180,25 @@ def __init__(self,
         self._words = words
         self._indexer = indexer
 
+    def to_explicit(self) -> 'ExplicitVocab':
+        """
+        Return an ExplicitVocab built from this vocab.
+
+        This method iterates over the known words and extracts all ngrams
+        within this vocab's bounds. Each ngram is hashed and mapped to an
+        index. This index is not necessarily unique per ngram: if hashes
+        collide, multiple ngrams are mapped to the same index.
+
+        The returned vocab will be unable to produce indices for unknown
+        ngrams.
+
+        The known indices of the new vocab cover `[0, vocab.upper_bound)`
+        without gaps.
+
+        Returns
+        -------
+        explicit_vocab : ExplicitVocab
+            The converted vocabulary.
+        """
+        return _bucket_to_explicit(self)
+
     def write_chunk(self, file: BinaryIO):
         _write_bucket_vocab(file, self)

@@ -244,6 +263,25 @@ def __init__(self,
         self._words = words
         self._indexer = indexer
 
+    def to_explicit(self) -> 'ExplicitVocab':
+        """
+        Return an ExplicitVocab built from this vocab.
+
+        This method iterates over the known words and extracts all ngrams
+        within this vocab's bounds. Each ngram is hashed and mapped to an
+        index. This index is not necessarily unique per ngram: if hashes
+        collide, multiple ngrams are mapped to the same index.
+
+        The returned vocab will be unable to produce indices for unknown
+        ngrams.
+
+        The known indices of the new vocab cover `[0, vocab.upper_bound)`
+        without gaps.
+
+        Returns
+        -------
+        explicit_vocab : ExplicitVocab
+            The converted vocabulary.
+        """
+        return _bucket_to_explicit(self)
+
     @property
     def subword_indexer(self) -> FastTextIndexer:
         return self._indexer

@@ -415,6 +453,25 @@ def load_explicit_vocab(file: Union[str, bytes, int, PathLike]
         return ExplicitVocab.read_chunk(inf)
 
 
+def _bucket_to_explicit(vocab: Union[FinalfusionBucketVocab, FastTextVocab]
+                        ) -> 'ExplicitVocab':
+    ngram_index = dict()
+    idx_index = dict()  # type: Dict[int, int]
+    ngram_list = []
+    for word in vocab.words:
+        token_ngrams = vocab.subwords(word)
+        for ngram in token_ngrams:
+            if ngram not in ngram_index:
+                ngram_list.append(ngram)
+                idx = vocab.subword_indexer(ngram)
+                if idx not in idx_index:
+                    idx_index[idx] = len(idx_index)
+                ngram_index[ngram] = idx_index[idx]
+    indexer = ExplicitIndexer(ngram_list, vocab.min_n, vocab.max_n,
+                              ngram_index)
+    return ExplicitVocab(vocab.words, indexer)
+
+
 def _write_bucket_vocab(file: BinaryIO,
                         vocab: Union[FastTextVocab, FinalfusionBucketVocab]):
     min_n_max_n_size = struct.calcsize("<II")
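
The heart of _bucket_to_explicit is a two-level remapping: idx_index compacts the bucket indices that are actually hit, and ngram_index points each ngram at its compacted index, so ngrams whose hashes collide keep sharing a single index. A self-contained toy illustration (not library code; the "hash" is a made-up lookup table):

    def remap(ngrams, bucket_index):
        idx_index = {}    # bucket index -> compact explicit index
        ngram_index = {}  # ngram -> compact explicit index
        for ngram in ngrams:
            idx = bucket_index(ngram)
            if idx not in idx_index:
                idx_index[idx] = len(idx_index)
            ngram_index[ngram] = idx_index[idx]
        return ngram_index

    # "dings" and "<gro" collide under the pretend hash, so they share index 0:
    print(remap(["dings", "<gro", "oß>"],
                lambda ngram: {"dings": 7, "<gro": 7, "oß>": 3}[ngram]))
    # {'dings': 0, '<gro': 0, 'oß>': 1}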

tests/test_embeddings.py

Lines changed: 41 additions & 0 deletions
@@ -192,3 +192,44 @@ def test_no_norms(vocab_array_tuple):
     embeddings = Embeddings(vocab=SimpleVocab(vocab), storage=NdArray(matrix))
     with pytest.raises(TypeError):
         _ = embeddings.embedding_with_norm("bla")
+
+
+def test_buckets_to_explicit(bucket_vocab_embeddings_fifu):
+    explicit = bucket_vocab_embeddings_fifu.bucket_to_explicit()
+    assert bucket_vocab_embeddings_fifu.vocab.words == explicit.vocab.words
+    for e1, e2 in zip(bucket_vocab_embeddings_fifu, explicit):
+        assert e1[0] == e2[0]
+        assert np.allclose(e1[1], e2[1])
+    assert bucket_vocab_embeddings_fifu.vocab.upper_bound == 1024 + len(
+        bucket_vocab_embeddings_fifu.vocab)
+    assert explicit.vocab.upper_bound == len(
+        bucket_vocab_embeddings_fifu.vocab) + 16
+    known = len(bucket_vocab_embeddings_fifu.vocab)
+    assert np.allclose(bucket_vocab_embeddings_fifu.storage[:known],
+                       explicit.storage[:known])
+    bucket_indexer = bucket_vocab_embeddings_fifu.vocab.subword_indexer
+    explicit_indexer = explicit.vocab.subword_indexer
+    for ngram in explicit_indexer:
+        assert np.allclose(
+            bucket_vocab_embeddings_fifu.storage[2 + bucket_indexer(ngram)],
+            explicit.storage[2 + explicit_indexer(ngram)])
+
+
+def test_buckets_to_explicit_roundtrip(bucket_vocab_embeddings_fifu, tmp_path):
+    filename = tmp_path / "bucket_to_explicit_embeds.fifu"
+    explicit = bucket_vocab_embeddings_fifu.bucket_to_explicit()
+    explicit.write(filename)
+    explicit2 = load_finalfusion(filename)
+    assert explicit.vocab == explicit2.vocab
+    assert np.allclose(explicit.storage, explicit2.storage)
+    assert np.allclose(explicit.norms, explicit2.norms)
+    assert np.allclose(bucket_vocab_embeddings_fifu.norms, explicit2.norms)
+    known = len(bucket_vocab_embeddings_fifu.vocab)
+    assert np.allclose(bucket_vocab_embeddings_fifu.storage[:known],
+                       explicit2.storage[:known])
+    bucket_indexer = bucket_vocab_embeddings_fifu.vocab.subword_indexer
+    explicit_indexer = explicit.vocab.subword_indexer
+    for ngram in explicit_indexer:
+        assert np.allclose(
+            bucket_vocab_embeddings_fifu.storage[2 + bucket_indexer(ngram)],
+            explicit2.storage[2 + explicit_indexer(ngram)])
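
A back-of-envelope reading of the fixture's constants (the vocabulary size of 2 is inferred from the hard-coded offset, not stated in the commit): conversion shrinks the subword part of the matrix from 1024 hash buckets to the 16 buckets actually hit by in-vocabulary ngrams.

    rows_bucketed = 2 + 1024   # words + hash buckets
    rows_explicit = 2 + 16     # words + surviving buckets
    print(rows_bucketed / rows_explicit)  # 57.0, i.e. a 57x smaller matrix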

tests/test_vocab.py

Lines changed: 12 additions & 0 deletions
@@ -143,6 +143,18 @@ def test_explicit_vocab_roundtrip(tmp_path):
     assert v == v2
 
 
+def test_bucket_to_explicit():
+    v = FinalfusionBucketVocab(["allerdings", "groß"])
+    explicit = v.to_explicit()
+    assert v.words == explicit.words
+    assert explicit.upper_bound == len(v) + 43
+    assert explicit.subword_indexer.upper_bound == 43
+    assert explicit.subword_indexer("dings") == explicit.subword_indexer(
+        "<gro")
+    assert v.subword_indexer("dings") == v.subword_indexer("<gro")
+    assert len(explicit.subword_indexer) == 44
+
+
 def test_fifu_buckets_roundtrip(tests_root, tmp_path):
     filename = tmp_path / "write_ff_buckets.fifu"
     v = load_vocab(tests_root / "data" / "ff_buckets.fifu")
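
Where the constants in this test come from (derived under the assumption that FinalfusionBucketVocab defaults to min_n=3, max_n=6; the commit does not restate the defaults):

    def n_ngrams(length, min_n=3, max_n=6):
        # A string of `length` characters has length - n + 1 ngrams of size n.
        return sum(length - n + 1 for n in range(min_n, max_n + 1))

    total = n_ngrams(len("<allerdings>")) + n_ngrams(len("<groß>"))  # 34 + 10
    assert total == 44       # len(explicit.subword_indexer)
    assert total - 1 == 43   # "dings" and "<gro" collide, so one index is shared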
