Skip to content

Commit 667fe9e

Browse files
authored
Use LFUCache
1 parent 535a18b commit 667fe9e

File tree

1 file changed

+100
-66
lines changed

1 file changed

+100
-66
lines changed

camel_tools/disambig/bert/unfactored.py

Lines changed: 100 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pathlib import Path
2828
import pickle
2929

30+
from cachetools import LFUCache
3031
import numpy as np
3132
import torch
3233
import torch.nn as nn
@@ -228,14 +229,18 @@ class BERTUnfactoredDisambiguator(Disambiguator):
228229
use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
229230
Defaults to True.
230231
batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
231-
ranking_cache (:obj:`dict`, optional): The cache dictionary of
232+
ranking_cache (:obj:`LFUCache`, optional): The cache dictionary of
232233
pre-computed scored analyses. Defaults to `None`.
234+
ranking_cache_size (:obj:`int`, optional): The number of unique word
235+
disambiguations to cache. If 0, no ranked analyses will be cached.
236+
The cache uses a least-frequently-used eviction policy.
237+
Defaults to 100000.
233238
"""
234239

235240
def __init__(self, model_path, analyzer,
236241
features=FEATURE_SET_MAP['feats_14'], top=1,
237242
scorer='uniform', tie_breaker='tag', use_gpu=True,
238-
batch_size=32, ranking_cache=None):
243+
batch_size=32, ranking_cache=None, ranking_cache_size=100000):
239244
self._model = {
240245
'unfactored': _BERTFeatureTagger(model_path)
241246
}
@@ -246,12 +251,23 @@ def __init__(self, model_path, analyzer,
246251
self._tie_breaker = tie_breaker
247252
self._use_gpu = use_gpu
248253
self._batch_size = batch_size
249-
self._ranking_cache = ranking_cache
250254
self._mle = _read_json(f'{model_path}/mle_model.json')
251255

256+
if ranking_cache is None:
257+
if ranking_cache_size <= 0:
258+
self._ranking_cache = None
259+
self._disambiguate_word_fn = self._disambiguate_word
260+
else:
261+
self._ranking_cache = LFUCache(ranking_cache_size)
262+
self._disambiguate_word_fn = self._disambiguate_word_cached
263+
else:
264+
self._ranking_cache = ranking_cache
265+
self._disambiguate_word_fn = self._disambiguate_word_cached
266+
252267
@staticmethod
253268
def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
254-
cache_size=10000, pretrained_cache=True):
269+
cache_size=10000, pretrained_cache=True,
270+
ranking_cache_size=100000):
255271
"""Load a pre-trained model provided with camel_tools.
256272
257273
Args:
@@ -270,6 +286,10 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
270286
pretrained_cache (:obj:`bool`, optional): The flag to use a
271287
pretrained cache that stores ranked analyses.
272288
Defaults to True.
289+
ranking_cache_size (:obj:`int`, optional): The number of unique
290+
word disambiguations to cache. If 0, no ranked analyses will be
291+
cached. The cache uses a least-frequently-used eviction policy.
292+
Defaults to 100000.
273293
274294
Returns:
275295
:obj:`BERTUnfactoredDisambiguator`: Instance with loaded
@@ -293,69 +313,78 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
293313
with open(cache_path, 'rb') as f:
294314
ranking_cache = pickle.load(f)
295315
else:
296-
ranking_cache = {}
297-
298-
return BERTUnfactoredDisambiguator(model_path,
299-
analyzer,
300-
top=top,
301-
features=features,
302-
scorer=scorer,
303-
tie_breaker=tie_breaker,
304-
use_gpu=use_gpu,
305-
batch_size=batch_size,
306-
ranking_cache=ranking_cache)
316+
ranking_cache = None
317+
318+
return BERTUnfactoredDisambiguator(
319+
model_path,
320+
analyzer,
321+
top=top,
322+
features=features,
323+
scorer=scorer,
324+
tie_breaker=tie_breaker,
325+
use_gpu=use_gpu,
326+
batch_size=batch_size,
327+
ranking_cache=ranking_cache,
328+
ranking_cache_size=ranking_cache_size)
307329

330+
@staticmethod
308331
def pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
309-
cache_size=10000, pretrained_cache=True):
310-
"""Load a pre-trained model from a config file.
311-
312-
Args:
313-
config (:obj:`str`): Config file that defines the model
314-
details. Defaults to `None`.
315-
top (:obj:`int`, optional): The maximum number of top analyses
316-
to return. Defaults to 1.
317-
use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
318-
Defaults to True.
319-
batch_size (:obj:`int`, optional): The batch size.
320-
Defaults to 32.
321-
cache_size (:obj:`int`, optional): If greater than zero, then
322-
the analyzer will cache the analyses for the cache_size
323-
most frequent words, otherwise no analyses will be cached.
324-
Defaults to 100000.
325-
pretrained_cache (:obj:`bool`, optional): The flag to use a
326-
pretrained cache that stores ranked analyses.
327-
Defaults to True.
332+
cache_size=10000, pretrained_cache=True,
333+
ranking_cache_size=100000):
334+
"""Load a pre-trained model from a config file.
328335
329-
Returns:
330-
:obj:`BERTUnfactoredDisambiguator`: Instance with loaded
331-
pre-trained model.
332-
"""
333-
334-
model_config = _read_json(config)
335-
model_path = model_config['model_path']
336-
features = FEATURE_SET_MAP[model_config['feature']]
337-
db = MorphologyDB(model_config['db_path'], 'a')
338-
analyzer = Analyzer(db,
339-
backoff=model_config['backoff'],
340-
cache_size=cache_size)
341-
scorer = model_config['scorer']
342-
tie_breaker = model_config['tie_breaker']
343-
if pretrained_cache:
344-
cache_path = model_config['ranking_cache']
345-
with open(cache_path, 'rb') as f:
346-
ranking_cache = pickle.load(f)
347-
else:
348-
ranking_cache = {}
349-
350-
return BERTUnfactoredDisambiguator(model_path,
351-
analyzer,
352-
top=top,
353-
features=features,
354-
scorer=scorer,
355-
tie_breaker=tie_breaker,
356-
use_gpu=use_gpu,
357-
batch_size=batch_size,
358-
ranking_cache=ranking_cache)
336+
Args:
337+
config (:obj:`str`): Config file that defines the model details.
338+
Defaults to `None`.
339+
top (:obj:`int`, optional): The maximum number of top analyses
340+
to return. Defaults to 1.
341+
use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
342+
Defaults to True.
343+
batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
344+
cache_size (:obj:`int`, optional): If greater than zero, then
345+
the analyzer will cache the analyses for the cache_size
346+
most frequent words, otherwise no analyses will be cached.
347+
Defaults to 100000.
348+
pretrained_cache (:obj:`bool`, optional): The flag to use a
349+
pretrained cache that stores ranked analyses.
350+
Defaults to True.
351+
ranking_cache_size (:obj:`int`, optional): The number of unique
352+
word disambiguations to cache. If 0, no ranked analyses will be
353+
cached. The cache uses a least-frequently-used eviction policy.
354+
Defaults to 100000.
355+
356+
Returns:
357+
:obj:`BERTUnfactoredDisambiguator`: Instance with loaded
358+
pre-trained model.
359+
"""
360+
361+
model_config = _read_json(config)
362+
model_path = model_config['model_path']
363+
features = FEATURE_SET_MAP[model_config['feature']]
364+
db = MorphologyDB(model_config['db_path'], 'a')
365+
analyzer = Analyzer(db,
366+
backoff=model_config['backoff'],
367+
cache_size=cache_size)
368+
scorer = model_config['scorer']
369+
tie_breaker = model_config['tie_breaker']
370+
if pretrained_cache:
371+
cache_path = model_config['ranking_cache']
372+
with open(cache_path, 'rb') as f:
373+
ranking_cache = pickle.load(f)
374+
else:
375+
ranking_cache = None
376+
377+
return BERTUnfactoredDisambiguator(
378+
model_path,
379+
analyzer,
380+
top=top,
381+
features=features,
382+
scorer=scorer,
383+
tie_breaker=tie_breaker,
384+
use_gpu=use_gpu,
385+
batch_size=batch_size,
386+
ranking_cache=ranking_cache,
387+
ranking_cache_size=ranking_cache_size)
359388

360389
def _predict_sentences(self, sentences):
361390
"""Predict the morphosyntactic labels of a list of sentences.
@@ -447,6 +476,11 @@ def _scored_analyses(self, word_dd, prediction):
447476
return scored_analyses[:self._top]
448477

449478
def _disambiguate_word(self, word, pred):
479+
scored_analyses = self._scored_analyses(word, pred)
480+
481+
return DisambiguatedWord(word, scored_analyses)
482+
483+
def _disambiguate_word_cached(self, word, pred):
450484
# Create a key for caching scored analysis given word and bert
451485
# predictions
452486
key = (word, tuple(pred[feat] for feat in self._features))
@@ -487,7 +521,7 @@ def disambiguate(self, sentence):
487521

488522
predictions = self._predict_sentence(sentence)
489523

490-
return [self._disambiguate_word(w, p)
524+
return [self._disambiguate_word_fn(w, p)
491525
for (w, p) in zip(sentence, predictions)]
492526

493527
def disambiguate_sentences(self, sentences):
@@ -507,7 +541,7 @@ def disambiguate_sentences(self, sentences):
507541

508542
for sentence, prediction in zip(sentences, predictions):
509543
disambiguated_sentence = [
510-
self._disambiguate_word(w, p)
544+
self._disambiguate_word_fn(w, p)
511545
for (w, p) in zip(sentence, prediction)
512546
]
513547
disambiguated_sentences.append(disambiguated_sentence)

0 commit comments

Comments
 (0)