@@ -27,6 +27,7 @@
 from pathlib import Path
 import pickle
 
+from cachetools import LFUCache
 import numpy as np
 import torch
 import torch.nn as nn
@@ -228,14 +229,18 @@ class BERTUnfactoredDisambiguator(Disambiguator):
         use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
             Defaults to True.
         batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
-        ranking_cache (:obj:`dict`, optional): The cache dictionary of
+        ranking_cache (:obj:`LFUCache`, optional): The cache dictionary of
             pre-computed scored analyses. Defaults to `None`.
+        ranking_cache_size (:obj:`int`, optional): The number of unique word
+            disambiguations to cache. If 0, no ranked analyses will be
+            cached. The cache uses a least-frequently-used eviction policy.
+            Defaults to 100000.
     """
 
     def __init__(self, model_path, analyzer,
                  features=FEATURE_SET_MAP['feats_14'], top=1,
                  scorer='uniform', tie_breaker='tag', use_gpu=True,
-                 batch_size=32, ranking_cache=None):
+                 batch_size=32, ranking_cache=None, ranking_cache_size=100000):
         self._model = {
             'unfactored': _BERTFeatureTagger(model_path)
         }
@@ -246,12 +251,23 @@ def __init__(self, model_path, analyzer,
         self._tie_breaker = tie_breaker
         self._use_gpu = use_gpu
         self._batch_size = batch_size
-        self._ranking_cache = ranking_cache
         self._mle = _read_json(f'{model_path}/mle_model.json')
 
+        if ranking_cache is None:
+            if ranking_cache_size <= 0:
+                self._ranking_cache = None
+                self._disambiguate_word_fn = self._disambiguate_word
+            else:
+                self._ranking_cache = LFUCache(ranking_cache_size)
+                self._disambiguate_word_fn = self._disambiguate_word_cached
+        else:
+            self._ranking_cache = ranking_cache
+            self._disambiguate_word_fn = self._disambiguate_word_cached
+
     @staticmethod
     def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
-                   cache_size=10000, pretrained_cache=True):
+                   cache_size=10000, pretrained_cache=True,
+                   ranking_cache_size=100000):
         """Load a pre-trained model provided with camel_tools.
 
         Args:
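
Aside (editorial, not part of the diff): the constructor change above binds self._disambiguate_word_fn once, at construction time, so each per-word call dispatches directly instead of re-testing whether caching is enabled. A minimal self-contained sketch of that dispatch pattern, using hypothetical names and assuming cachetools is installed:

    from cachetools import LFUCache

    class CacheDispatchSketch:
        """Hypothetical stand-in mirroring the __init__ logic above."""

        def __init__(self, ranking_cache_size=100000):
            if ranking_cache_size <= 0:
                # Caching disabled: point at the uncached implementation.
                self._cache = None
                self._process_fn = self._process
            else:
                # Least-frequently-used cache bounded to the requested size.
                self._cache = LFUCache(ranking_cache_size)
                self._process_fn = self._process_cached

        def _process(self, word):
            # Stand-in for the expensive scoring work.
            return word.upper()

        def _process_cached(self, word):
            if word not in self._cache:
                self._cache[word] = self._process(word)
            return self._cache[word]
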
@@ -270,6 +286,10 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
             pretrained_cache (:obj:`bool`, optional): The flag to use a
                 pretrained cache that stores ranked analyses.
                 Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If 0, no ranked analyses will be
+                cached. The cache uses a least-frequently-used eviction policy.
+                Defaults to 100000.
 
         Returns:
             :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
@@ -293,69 +313,78 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
             with open(cache_path, 'rb') as f:
                 ranking_cache = pickle.load(f)
         else:
-            ranking_cache = {}
-
-        return BERTUnfactoredDisambiguator(model_path,
-                                           analyzer,
-                                           top=top,
-                                           features=features,
-                                           scorer=scorer,
-                                           tie_breaker=tie_breaker,
-                                           use_gpu=use_gpu,
-                                           batch_size=batch_size,
-                                           ranking_cache=ranking_cache)
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)
 
+    @staticmethod
     def pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
-                               cache_size=10000, pretrained_cache=True):
-        """Load a pre-trained model from a config file.
-
-        Args:
-            config (:obj:`str`): Config file that defines the model
-                details. Defaults to `None`.
-            top (:obj:`int`, optional): The maximum number of top analyses
-                to return. Defaults to 1.
-            use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
-                Defaults to True.
-            batch_size (:obj:`int`, optional): The batch size.
-                Defaults to 32.
-            cache_size (:obj:`int`, optional): If greater than zero, then
-                the analyzer will cache the analyses for the cache_size
-                most frequent words, otherwise no analyses will be cached.
-                Defaults to 100000.
-            pretrained_cache (:obj:`bool`, optional): The flag to use a
-                pretrained cache that stores ranked analyses.
-                Defaults to True.
+                               cache_size=10000, pretrained_cache=True,
+                               ranking_cache_size=100000):
+        """Load a pre-trained model from a config file.
 
-        Returns:
-            :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
-                pre-trained model.
-        """
-
-        model_config = _read_json(config)
-        model_path = model_config['model_path']
-        features = FEATURE_SET_MAP[model_config['feature']]
-        db = MorphologyDB(model_config['db_path'], 'a')
-        analyzer = Analyzer(db,
-                            backoff=model_config['backoff'],
-                            cache_size=cache_size)
-        scorer = model_config['scorer']
-        tie_breaker = model_config['tie_breaker']
-        if pretrained_cache:
-            cache_path = model_config['ranking_cache']
-            with open(cache_path, 'rb') as f:
-                ranking_cache = pickle.load(f)
-        else:
-            ranking_cache = {}
-
-        return BERTUnfactoredDisambiguator(model_path,
-                                           analyzer,
-                                           top=top,
-                                           features=features,
-                                           scorer=scorer,
-                                           tie_breaker=tie_breaker,
-                                           use_gpu=use_gpu,
-                                           batch_size=batch_size,
-                                           ranking_cache=ranking_cache)
+        Args:
+            config (:obj:`str`): Config file that defines the model
+                details.
+            top (:obj:`int`, optional): The maximum number of top analyses
+                to return. Defaults to 1.
+            use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
+                Defaults to True.
+            batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
+            cache_size (:obj:`int`, optional): If greater than zero, then
+                the analyzer will cache the analyses for the cache_size
+                most frequent words, otherwise no analyses will be cached.
+                Defaults to 10000.
+            pretrained_cache (:obj:`bool`, optional): The flag to use a
+                pretrained cache that stores ranked analyses.
+                Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If 0, no ranked analyses will be
+                cached. The cache uses a least-frequently-used eviction policy.
+                Defaults to 100000.
+
+        Returns:
+            :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
+                pre-trained model.
+        """
+
+        model_config = _read_json(config)
+        model_path = model_config['model_path']
+        features = FEATURE_SET_MAP[model_config['feature']]
+        db = MorphologyDB(model_config['db_path'], 'a')
+        analyzer = Analyzer(db,
+                            backoff=model_config['backoff'],
+                            cache_size=cache_size)
+        scorer = model_config['scorer']
+        tie_breaker = model_config['tie_breaker']
+        if pretrained_cache:
+            cache_path = model_config['ranking_cache']
+            with open(cache_path, 'rb') as f:
+                ranking_cache = pickle.load(f)
+        else:
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)
 
     def _predict_sentences(self, sentences):
         """Predict the morphosyntactic labels of a list of sentences.
@@ -447,6 +476,11 @@ def _scored_analyses(self, word_dd, prediction):
         return scored_analyses[:self._top]
 
     def _disambiguate_word(self, word, pred):
+        scored_analyses = self._scored_analyses(word, pred)
+
+        return DisambiguatedWord(word, scored_analyses)
+
+    def _disambiguate_word_cached(self, word, pred):
         # Create a key for caching scored analysis given word and bert
         # predictions
         key = (word, tuple(pred[feat] for feat in self._features))
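
Note (editorial, not part of the diff): the cached path keys each entry on the word together with its predicted feature values, so the same surface word under different BERT predictions is scored and cached separately. A hedged sketch of that memoization, with a placeholder feature set and a stand-in for the expensive analyze-and-rank step:

    from cachetools import LFUCache

    FEATURES = ('pos', 'gen', 'num')    # placeholder feature set
    ranking_cache = LFUCache(maxsize=100000)

    def _scored_analyses_stub(word, pred):
        # Stand-in for the expensive analyzer + ranking step.
        return [(word, tuple(pred[feat] for feat in FEATURES), 1.0)]

    def disambiguate_word_cached(word, pred):
        # Word + predicted feature values form a hashable cache key.
        key = (word, tuple(pred[feat] for feat in FEATURES))
        if key not in ranking_cache:
            ranking_cache[key] = _scored_analyses_stub(word, pred)
        return ranking_cache[key]
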
@@ -487,7 +521,7 @@ def disambiguate(self, sentence):
 
         predictions = self._predict_sentence(sentence)
 
-        return [self._disambiguate_word(w, p)
+        return [self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, predictions)]
 
     def disambiguate_sentences(self, sentences):
@@ -507,7 +541,7 @@ def disambiguate_sentences(self, sentences):
 
         for sentence, prediction in zip(sentences, predictions):
             disambiguated_sentence = [
-                self._disambiguate_word_fn(w, p)
+                self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, prediction)
             ]
             disambiguated_sentences.append(disambiguated_sentence)
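
Usage note (editorial, not part of the diff): a sketch of how the new knob is exercised through pretrained(), assuming the camel_tools data for the msa model is installed. Passing pretrained_cache=False skips the shipped pickled cache, so a fresh LFU cache of the requested size is created; ranking_cache_size=0 would disable ranked-analysis caching entirely:

    from camel_tools.disambig.bert import BERTUnfactoredDisambiguator

    # Fresh LFU ranking cache capped at 50000 entries.
    bert = BERTUnfactoredDisambiguator.pretrained('msa',
                                                  pretrained_cache=False,
                                                  ranking_cache_size=50000)

    sentence = ['ذهب', 'الولد', 'الى', 'المدرسة']
    disambig = bert.disambiguate(sentence)
    top_analysis = disambig[0].analyses[0].analysis
    print(top_analysis['lex'])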