
Commit 01438bb

Merge pull request #86 from go-inoue/ranking_cache
Add ranking cache
2 parents c92d2c1 + 8dda232

File tree

1 file changed

+115 -67 lines changed

camel_tools/disambig/bert/unfactored.py

Lines changed: 115 additions & 67 deletions
@@ -27,6 +27,7 @@
 from pathlib import Path
 import pickle

+from cachetools import LFUCache
 import numpy as np
 import torch
 import torch.nn as nn
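
The only new dependency here is cachetools; its LFUCache is a dict-like mapping with a fixed capacity that evicts the least-frequently-used entry when full. A minimal standalone sketch of that behavior (illustrative, not part of the commit):

from cachetools import LFUCache

cache = LFUCache(maxsize=2)
cache['a'] = 1
cache['b'] = 2
_ = cache['a']         # 'a' has now been used more often than 'b'
cache['c'] = 3         # over capacity: least-frequently-used 'b' is evicted
print('b' in cache)    # False
print('a' in cache)    # True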
@@ -228,14 +229,18 @@ class BERTUnfactoredDisambiguator(Disambiguator):
         use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
             Defaults to True.
         batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
-        ranking_cache (:obj:`dict`, optional): The cache dictionary of
-            pre-computed scored analyses. Defaults to `None`.
+        ranking_cache (:obj:`LFUCache`, optional): The cache of pre-computed
+            scored analyses. Defaults to `None`.
+        ranking_cache_size (:obj:`int`, optional): The number of unique word
+            disambiguations to cache. If 0, no ranked analyses will be cached.
+            The cache uses a least-frequently-used eviction policy.
+            Defaults to 100000.
     """

     def __init__(self, model_path, analyzer,
                  features=FEATURE_SET_MAP['feats_14'], top=1,
                  scorer='uniform', tie_breaker='tag', use_gpu=True,
-                 batch_size=32, ranking_cache=None):
+                 batch_size=32, ranking_cache=None, ranking_cache_size=100000):
         self._model = {
             'unfactored': _BERTFeatureTagger(model_path)
         }
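
With the new signature, callers can either size the internal ranking cache or inject a pre-populated one. A hypothetical construction sketch; the model path is a placeholder and the analyzer setup assumes the builtin camel_tools morphology database:

from cachetools import LFUCache
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
from camel_tools.morphology.analyzer import Analyzer
from camel_tools.morphology.database import MorphologyDB

analyzer = Analyzer(MorphologyDB.builtin_db())   # assumed analyzer setup

# Let the constructor build a fresh 50k-entry LFU cache.
bert = BERTUnfactoredDisambiguator('/path/to/model', analyzer,
                                   ranking_cache_size=50000)

# Or inject a cache built elsewhere; ranking_cache_size is then ignored.
bert = BERTUnfactoredDisambiguator('/path/to/model', analyzer,
                                   ranking_cache=LFUCache(100000))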
@@ -246,18 +251,23 @@ def __init__(self, model_path, analyzer,
         self._tie_breaker = tie_breaker
         self._use_gpu = use_gpu
         self._batch_size = batch_size
+        self._mle = _read_json(f'{model_path}/mle_model.json')

-        if ranking_cache is not None:
-            with open(ranking_cache, 'rb') as f:
-                self._ranking_cache = pickle.load(f)
+        if ranking_cache is None:
+            if ranking_cache_size <= 0:
+                self._ranking_cache = None
+                self._disambiguate_word_fn = self._disambiguate_word
+            else:
+                self._ranking_cache = LFUCache(ranking_cache_size)
+                self._disambiguate_word_fn = self._disambiguate_word_cached
         else:
-            self._ranking_cache = {}
-
-        self._mle = _read_json(f'{model_path}/mle_model.json')
+            self._ranking_cache = ranking_cache
+            self._disambiguate_word_fn = self._disambiguate_word_cached

     @staticmethod
     def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
-                   cache_size=10000):
+                   cache_size=10000, pretrained_cache=True,
+                   ranking_cache_size=100000):
         """Load a pre-trained model provided with camel_tools.

         Args:
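
Note the design choice in the __init__ hunk above: the caching strategy is resolved once and bound to self._disambiguate_word_fn, so the per-word hot path never re-checks the configuration. The same pattern in isolation (a generic sketch, not camel_tools code):

class Dispatcher:
    def __init__(self, use_cache):
        # Pick the implementation once at construction time.
        self._fn = self._cached if use_cache else self._plain

    def _plain(self, x):
        return x * 2

    def _cached(self, x):
        # The real class would consult an LFUCache before recomputing.
        return x * 2

    def run(self, x):
        # Hot path: a plain attribute call, no branching per item.
        return self._fn(x)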
@@ -273,6 +283,14 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
                 the analyzer will cache the analyses for the cache_size most
                 frequent words, otherwise no analyses will be cached.
                 Defaults to 10000.
+            pretrained_cache (:obj:`bool`, optional): The flag to use a
+                pretrained cache that stores ranked analyses.
+                Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If 0, no ranked analyses will be
+                cached. The cache uses a least-frequently-used eviction policy.
+                This argument is ignored if pretrained_cache is True.
+                Defaults to 100000.

         Returns:
             :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
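
Taken together, the two new keyword arguments give pretrained() three modes: load the shipped cache, build an empty LFU cache of a chosen size, or disable ranking caching entirely with ranking_cache_size=0. A hypothetical usage sketch (the import path follows the camel_tools package layout):

from camel_tools.disambig.bert import BERTUnfactoredDisambiguator

# Default: load the pretrained ranking cache shipped with the model.
msa = BERTUnfactoredDisambiguator.pretrained('msa')

# Start from an empty 50k-entry LFU cache instead; ranking_cache_size
# only takes effect because pretrained_cache is False.
msa = BERTUnfactoredDisambiguator.pretrained('msa',
                                             pretrained_cache=False,
                                             ranking_cache_size=50000)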
@@ -289,61 +307,86 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
                             cache_size=cache_size)
         scorer = model_config['scorer']
         tie_breaker = model_config['tie_breaker']
-        ranking_cache = model_config['ranking_cache']
-
-        return BERTUnfactoredDisambiguator(model_path,
-                                           analyzer,
-                                           top=top,
-                                           features=features,
-                                           scorer=scorer,
-                                           tie_breaker=tie_breaker,
-                                           use_gpu=use_gpu,
-                                           batch_size=batch_size,
-                                           ranking_cache=ranking_cache)
-
-    def pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
-                               cache_size=10000):
-        """Load a pre-trained model from a config file.
-
-        Args:
-            config (:obj:`str`): Config file that defines the model
-                details. Defaults to `None`.
-            top (:obj:`int`, optional): The maximum number of top analyses
-                to return. Defaults to 1.
-            use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
-                Defaults to True.
-            batch_size (:obj:`int`, optional): The batch size.
-                Defaults to 32.
-            cache_size (:obj:`int`, optional): If greater than zero, then
-                the analyzer will cache the analyses for the cache_size
-                most frequent words, otherwise no analyses will be cached.
-                Defaults to 10000.
-
-        Returns:
-            :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
-                pre-trained model.
-        """
-
-        model_config = _read_json(config)
-        model_path = model_config['model_path']
-        features = FEATURE_SET_MAP[model_config['feature']]
-        db = MorphologyDB(model_config['db_path'], 'a')
-        analyzer = Analyzer(db,
-                            backoff=model_config['backoff'],
-                            cache_size=cache_size)
-        scorer = model_config['scorer']
-        tie_breaker = model_config['tie_breaker']
-        ranking_cache = model_config['ranking_cache']
-
-        return BERTUnfactoredDisambiguator(model_path,
-                                           analyzer,
-                                           top=top,
-                                           features=features,
-                                           scorer=scorer,
-                                           tie_breaker=tie_breaker,
-                                           use_gpu=use_gpu,
-                                           batch_size=batch_size,
-                                           ranking_cache=ranking_cache)
+        if pretrained_cache:
+            cache_info = CATALOGUE.get_dataset('DisambigRankingCache',
+                                               model_config['ranking_cache'])
+            cache_path = Path(cache_info.path, 'default_cache.pickle')
+            with open(cache_path, 'rb') as f:
+                ranking_cache = pickle.load(f)
+        else:
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)
+
+    @staticmethod
+    def _pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
+                                cache_size=10000, pretrained_cache=True,
+                                ranking_cache_size=100000):
+        """Load a pre-trained model from a config file.
+
+        Args:
+            config (:obj:`str`): Config file that defines the model details.
+                Defaults to `None`.
+            top (:obj:`int`, optional): The maximum number of top analyses
+                to return. Defaults to 1.
+            use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
+                Defaults to True.
+            batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
+            cache_size (:obj:`int`, optional): If greater than zero, then
+                the analyzer will cache the analyses for the cache_size
+                most frequent words, otherwise no analyses will be cached.
+                Defaults to 10000.
+            pretrained_cache (:obj:`bool`, optional): The flag to use a
+                pretrained cache that stores ranked analyses.
+                Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If 0, no ranked analyses will be
+                cached. The cache uses a least-frequently-used eviction policy.
+                This argument is ignored if pretrained_cache is True.
+                Defaults to 100000.
+
+        Returns:
+            :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
+                pre-trained model.
+        """
+
+        model_config = _read_json(config)
+        model_path = model_config['model_path']
+        features = FEATURE_SET_MAP[model_config['feature']]
+        db = MorphologyDB(model_config['db_path'], 'a')
+        analyzer = Analyzer(db,
+                            backoff=model_config['backoff'],
+                            cache_size=cache_size)
+        scorer = model_config['scorer']
+        tie_breaker = model_config['tie_breaker']
+        if pretrained_cache:
+            cache_path = model_config['ranking_cache']
+            with open(cache_path, 'rb') as f:
+                ranking_cache = pickle.load(f)
+        else:
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)

     def _predict_sentences(self, sentences):
         """Predict the morphosyntactic labels of a list of sentences.
@@ -435,6 +478,11 @@ def _scored_analyses(self, word_dd, prediction):
         return scored_analyses[:self._top]

     def _disambiguate_word(self, word, pred):
+        scored_analyses = self._scored_analyses(word, pred)
+
+        return DisambiguatedWord(word, scored_analyses)
+
+    def _disambiguate_word_cached(self, word, pred):
         # Create a key for caching scored analysis given word and bert
         # predictions
         key = (word, tuple(pred[feat] for feat in self._features))
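
The cache key pairs the surface word with the tuple of its predicted feature values, so two tokens share an entry only when both the word and the full BERT prediction agree. A toy illustration of the key construction (feature names and values are invented):

features = ('pos', 'gen', 'num')                  # example feature subset
pred = {'pos': 'noun', 'gen': 'm', 'num': 's'}    # example BERT prediction

key = ('كتاب', tuple(pred[feat] for feat in features))
print(key)    # ('كتاب', ('noun', 'm', 's')) -- hashable, usable as an LFUCache key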
@@ -475,7 +523,7 @@ def disambiguate(self, sentence):

         predictions = self._predict_sentence(sentence)

-        return [self._disambiguate_word(w, p)
+        return [self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, predictions)]

     def disambiguate_sentences(self, sentences):
@@ -495,7 +543,7 @@ def disambiguate_sentences(self, sentences):

         for sentence, prediction in zip(sentences, predictions):
             disambiguated_sentence = [
-                self._disambiguate_word(w, p)
+                self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, prediction)
             ]
             disambiguated_sentences.append(disambiguated_sentence)
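
Since both public entry points now route through self._disambiguate_word_fn, repeated word/prediction pairs across a batch are scored once and then served from the cache. A hypothetical end-to-end usage, with sentences as pre-tokenized word lists per the camel_tools API:

from camel_tools.disambig.bert import BERTUnfactoredDisambiguator

sentences = [['ذهب', 'الولد', 'الى', 'المدرسة'],
             ['رجع', 'الولد', 'من', 'المدرسة']]

disambig = BERTUnfactoredDisambiguator.pretrained('msa')
for sent in disambig.disambiguate_sentences(sentences):
    # Recurring words such as 'الولد' reuse their cached scored analyses.
    print([dw.word for dw in sent])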
