 from pathlib import Path
 import pickle

+from cachetools import LFUCache
 import numpy as np
 import torch
 import torch.nn as nn
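The new cachetools import supplies the LFU cache used throughout this change. As a quick refresher (this snippet is illustrative, not part of the diff), LFUCache behaves like a size-bounded dict that evicts the least frequently used entry when full:

    from cachetools import LFUCache

    cache = LFUCache(maxsize=2)
    cache['a'] = 1
    cache['b'] = 2
    _ = cache['a']       # read 'a' so it is used more often than 'b'
    cache['c'] = 3       # cache is full: least-frequently-used entry is evicted
    print('b' in cache)  # False: 'b' was dropped, 'a' and 'c' remain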
@@ -228,14 +229,18 @@ class BERTUnfactoredDisambiguator(Disambiguator):
         use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
             Defaults to True.
         batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
-        ranking_cache (:obj:`dict`, optional): The cache dictionary of
-            pre-computed scored analyses. Defaults to `None`.
+        ranking_cache (:obj:`LFUCache`, optional): The cache of pre-computed
+            scored analyses. Defaults to `None`.
+        ranking_cache_size (:obj:`int`, optional): The number of unique word
+            disambiguations to cache. If zero or negative, no ranked analyses
+            will be cached. The cache uses a least-frequently-used eviction
+            policy. Defaults to 100000.
     """

     def __init__(self, model_path, analyzer,
                  features=FEATURE_SET_MAP['feats_14'], top=1,
                  scorer='uniform', tie_breaker='tag', use_gpu=True,
-                 batch_size=32, ranking_cache=None):
+                 batch_size=32, ranking_cache=None, ranking_cache_size=100000):
         self._model = {
             'unfactored': _BERTFeatureTagger(model_path)
         }
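For context, a hedged construction example; the model path below is a placeholder, not a value from this diff, and the built-in database assumes the camel_tools data has been installed:

    from camel_tools.morphology.database import MorphologyDB
    from camel_tools.morphology.analyzer import Analyzer

    # Placeholder setup for illustration only.
    db = MorphologyDB.builtin_db()
    analyzer = Analyzer(db)
    bert = BERTUnfactoredDisambiguator('/path/to/model', analyzer,
                                       ranking_cache_size=50000)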
@@ -246,18 +251,23 @@ def __init__(self, model_path, analyzer,
         self._tie_breaker = tie_breaker
         self._use_gpu = use_gpu
         self._batch_size = batch_size
+        self._mle = _read_json(f'{model_path}/mle_model.json')

-        if ranking_cache is not None:
-            with open(ranking_cache, 'rb') as f:
-                self._ranking_cache = pickle.load(f)
+        if ranking_cache is None:
+            if ranking_cache_size <= 0:
+                self._ranking_cache = None
+                self._disambiguate_word_fn = self._disambiguate_word
+            else:
+                self._ranking_cache = LFUCache(ranking_cache_size)
+                self._disambiguate_word_fn = self._disambiguate_word_cached
         else:
-            self._ranking_cache = {}
-
-        self._mle = _read_json(f'{model_path}/mle_model.json')
+            self._ranking_cache = ranking_cache
+            self._disambiguate_word_fn = self._disambiguate_word_cached

     @staticmethod
     def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
-                   cache_size=10000):
+                   cache_size=10000, pretrained_cache=True,
+                   ranking_cache_size=100000):
         """Load a pre-trained model provided with camel_tools.

         Args:
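Note that __init__ now selects the per-word implementation once and binds it to self._disambiguate_word_fn, so the hot path never re-checks the configuration on every token. A minimal sketch of the same dispatch pattern, with hypothetical names (Ranker, rank, and _score are not from this codebase):

    from cachetools import LFUCache

    class Ranker:
        """Toy stand-in for the disambiguator's dispatch setup."""

        def __init__(self, cache_size):
            # Bind the implementation once; callers always use self.rank.
            if cache_size > 0:
                self._cache = LFUCache(cache_size)
                self.rank = self._rank_cached
            else:
                self._cache = None
                self.rank = self._rank_uncached

        def _rank_uncached(self, word):
            return self._score(word)

        def _rank_cached(self, word):
            if word not in self._cache:
                self._cache[word] = self._score(word)
            return self._cache[word]

        def _score(self, word):
            return len(word)  # stand-in for the real scoring step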
@@ -273,6 +283,14 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
                 the analyzer will cache the analyses for the cache_size most
                 frequent words, otherwise no analyses will be cached.
                 Defaults to 10000.
+            pretrained_cache (:obj:`bool`, optional): The flag to use a
+                pretrained cache that stores ranked analyses.
+                Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If zero or negative, no ranked
+                analyses will be cached. The cache uses a least-frequently-used
+                eviction policy. This argument is ignored if pretrained_cache
+                is True. Defaults to 100000.

         Returns:
             :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
@@ -289,61 +307,86 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
                             cache_size=cache_size)
         scorer = model_config['scorer']
         tie_breaker = model_config['tie_breaker']
-        ranking_cache = model_config['ranking_cache']
-
-        return BERTUnfactoredDisambiguator(model_path,
-                                           analyzer,
-                                           top=top,
-                                           features=features,
-                                           scorer=scorer,
-                                           tie_breaker=tie_breaker,
-                                           use_gpu=use_gpu,
-                                           batch_size=batch_size,
-                                           ranking_cache=ranking_cache)
-
-    def pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
-                               cache_size=10000):
-            """Load a pre-trained model from a config file.
-
-            Args:
-                config (:obj:`str`): Config file that defines the model
-                    details. Defaults to `None`.
-                top (:obj:`int`, optional): The maximum number of top analyses
-                    to return. Defaults to 1.
-                use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
-                    Defaults to True.
-                batch_size (:obj:`int`, optional): The batch size.
-                    Defaults to 32.
-                cache_size (:obj:`int`, optional): If greater than zero, then
-                    the analyzer will cache the analyses for the cache_size
-                    most frequent words, otherwise no analyses will be cached.
-                    Defaults to 100000.
-
-            Returns:
-                :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
-                pre-trained model.
-            """
-
-            model_config = _read_json(config)
-            model_path = model_config['model_path']
-            features = FEATURE_SET_MAP[model_config['feature']]
-            db = MorphologyDB(model_config['db_path'], 'a')
-            analyzer = Analyzer(db,
-                                backoff=model_config['backoff'],
-                                cache_size=cache_size)
-            scorer = model_config['scorer']
-            tie_breaker = model_config['tie_breaker']
-            ranking_cache = model_config['ranking_cache']
-
-            return BERTUnfactoredDisambiguator(model_path,
-                                               analyzer,
-                                               top=top,
-                                               features=features,
-                                               scorer=scorer,
-                                               tie_breaker=tie_breaker,
-                                               use_gpu=use_gpu,
-                                               batch_size=batch_size,
-                                               ranking_cache=ranking_cache)
+        if pretrained_cache:
+            cache_info = CATALOGUE.get_dataset('DisambigRankingCache',
+                                               model_config['ranking_cache'])
+            cache_path = Path(cache_info.path, 'default_cache.pickle')
+            with open(cache_path, 'rb') as f:
+                ranking_cache = pickle.load(f)
+        else:
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)
+
+    @staticmethod
+    def _pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
+                                cache_size=10000, pretrained_cache=True,
+                                ranking_cache_size=100000):
+        """Load a pre-trained model from a config file.
+
+        Args:
+            config (:obj:`str`): Path to a config file that defines the
+                model details.
+            top (:obj:`int`, optional): The maximum number of top analyses
+                to return. Defaults to 1.
+            use_gpu (:obj:`bool`, optional): The flag to use a GPU or not.
+                Defaults to True.
+            batch_size (:obj:`int`, optional): The batch size. Defaults to 32.
+            cache_size (:obj:`int`, optional): If greater than zero, then
+                the analyzer will cache the analyses for the cache_size
+                most frequent words, otherwise no analyses will be cached.
+                Defaults to 10000.
+            pretrained_cache (:obj:`bool`, optional): The flag to use a
+                pretrained cache that stores ranked analyses.
+                Defaults to True.
+            ranking_cache_size (:obj:`int`, optional): The number of unique
+                word disambiguations to cache. If zero or negative, no ranked
+                analyses will be cached. The cache uses a least-frequently-used
+                eviction policy. This argument is ignored if pretrained_cache
+                is True. Defaults to 100000.
+
+        Returns:
+            :obj:`BERTUnfactoredDisambiguator`: Instance with loaded
+            pre-trained model.
+        """
+
+        model_config = _read_json(config)
+        model_path = model_config['model_path']
+        features = FEATURE_SET_MAP[model_config['feature']]
+        db = MorphologyDB(model_config['db_path'], 'a')
+        analyzer = Analyzer(db,
+                            backoff=model_config['backoff'],
+                            cache_size=cache_size)
+        scorer = model_config['scorer']
+        tie_breaker = model_config['tie_breaker']
+        if pretrained_cache:
+            cache_path = model_config['ranking_cache']
+            with open(cache_path, 'rb') as f:
+                ranking_cache = pickle.load(f)
+        else:
+            ranking_cache = None
+
+        return BERTUnfactoredDisambiguator(
+            model_path,
+            analyzer,
+            top=top,
+            features=features,
+            scorer=scorer,
+            tie_breaker=tie_breaker,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            ranking_cache=ranking_cache,
+            ranking_cache_size=ranking_cache_size)

     def _predict_sentences(self, sentences):
         """Predict the morphosyntactic labels of a list of sentences.
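A short usage sketch of the updated factory (illustrative, not from the diff; assumes the msa model and its ranking cache have already been installed via camel_data):

    # Load the pretrained MSA model but skip the shipped ranking cache,
    # falling back to a fresh LFU cache holding up to 200000 entries.
    bert = BERTUnfactoredDisambiguator.pretrained('msa',
                                                  pretrained_cache=False,
                                                  ranking_cache_size=200000)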
@@ -435,6 +478,11 @@ def _scored_analyses(self, word_dd, prediction):
         return scored_analyses[:self._top]

     def _disambiguate_word(self, word, pred):
+        scored_analyses = self._scored_analyses(word, pred)
+
+        return DisambiguatedWord(word, scored_analyses)
+
+    def _disambiguate_word_cached(self, word, pred):
         # Create a key for caching scored analyses given the word and BERT
         # predictions
         key = (word, tuple(pred[feat] for feat in self._features))
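The cache key combines the surface word with the BERT-predicted values of the configured feature set, so two occurrences of a word share an entry only when their predictions also agree. A toy illustration (feature names and values here are made up):

    features = ('pos', 'gen', 'num')
    pred = {'pos': 'noun', 'gen': 'm', 'num': 's', 'score': 0.91}
    key = ('كتاب', tuple(pred[feat] for feat in features))
    # key == ('كتاب', ('noun', 'm', 's')) -- hashable, usable as an LFUCache key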
@@ -475,7 +523,7 @@ def disambiguate(self, sentence):

         predictions = self._predict_sentence(sentence)

-        return [self._disambiguate_word(w, p)
+        return [self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, predictions)]

     def disambiguate_sentences(self, sentences):
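The public API is unchanged by the dispatch swap. A usage sketch (assumes a loaded bert instance as above; each result pairs the word with its ranked, scored analyses):

    sentence = ['ذهب', 'الولد', 'الى', 'المدرسة']
    disambig = bert.disambiguate(sentence)
    # Top-scoring analysis for the first word.
    top = disambig[0].analyses[0]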
@@ -495,7 +543,7 @@ def disambiguate_sentences(self, sentences):

         for sentence, prediction in zip(sentences, predictions):
             disambiguated_sentence = [
-                self._disambiguate_word(w, p)
+                self._disambiguate_word_fn(w, p)
                 for (w, p) in zip(sentence, prediction)
             ]
             disambiguated_sentences.append(disambiguated_sentence)
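Since both disambiguate and disambiguate_sentences now route through self._disambiguate_word_fn, the ranking cache is shared across calls and sentences. A brief batch sketch (illustrative):

    sentences = [['ذهب', 'الولد'], ['ذهب', 'الرجل']]
    results = bert.disambiguate_sentences(sentences)
    # The repeated word 'ذهب' is scored once, then served from the cache
    # whenever BERT predicts the same feature values for it.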