3333from torch .utils .data import DataLoader
3434from transformers import BertForTokenClassification , BertTokenizer
3535
36+ from camel_tools .data import CATALOGUE
3637from camel_tools .morphology .database import MorphologyDB
3738from camel_tools .morphology .analyzer import Analyzer
38-
39- from camel_tools .data import CATALOGUE
4039from camel_tools .disambig .common import Disambiguator , DisambiguatedWord
4140from camel_tools .disambig .common import ScoredAnalysis
42-
4341from camel_tools .disambig .bert .bert_morph_dataset import MorphDataset
4442from camel_tools .disambig .score_function import score_analysis_uniform
4543from camel_tools .disambig .score_function import FEATURE_SET_MAP
4644
4745
46+ __all__ = [
47+ 'BERTUnfactoredDisambiguator' ,
48+ ]
49+
50+
4851_SCORING_FUNCTION_MAP = {
4952 'uniform' : score_analysis_uniform
5053}
@@ -219,8 +222,8 @@ def _predict_sentences(self, sentences):
219222 for feat in pred .split ('__' ):
220223 f , v = feat .split (':' )
221224 d [f ] = v
222- d ['lex' ] = word # copy the word when analyzer is not used
223- d ['diac' ] = word # copy the word when analyzer is not used
225+ d ['lex' ] = word # Copy the word when analyzer is not used
226+ d ['diac' ] = word # Copy the word when analyzer is not used
224227 parsed_prediction .append (d )
225228 parsed_predictions .append (parsed_prediction )
226229
@@ -246,8 +249,8 @@ def _predict_sentence(self, sentence):
246249 for feat in pred .split ('__' ):
247250 f , v = feat .split (':' )
248251 d [f ] = v
249- d ['lex' ] = word # copy the word when analyzer is not used
250- d ['diac' ] = word # copy the word when analyzer is not used
252+ d ['lex' ] = word # Copy the word when analyzer is not used
253+ d ['diac' ] = word # Copy the word when analyzer is not used
251254 parsed_predictions .append (d )
252255
253256 return parsed_predictions
@@ -257,7 +260,7 @@ def _scored_analyses(self, word_dd, prediction):
257260 analyses = self ._analyzer .analyze (word_dd )
258261
259262 if len (analyses ) == 0 :
260- # if the word is not found in the analyzer,
263+ # If the word is not found in the analyzer,
261264 # return the predictions from BERT
262265 return [ScoredAnalysis (0 , bert_analysis )]
263266
@@ -273,13 +276,13 @@ def _scored_analyses(self, word_dd, prediction):
273276 scored_analyses = [ScoredAnalysis (s [0 ] / max_score , s [1 ])
274277 for s in scored ]
275278 else :
276- # if the max score is 0, do not divide
279+ # If the max score is 0, do not divide
277280 scored_analyses = [ScoredAnalysis (s [0 ], s [1 ]) for s in scored ]
278281
279282 return scored_analyses [:self ._top ]
280283
281284 def _disambiguate_word (self , word , pred ):
282- # create a key for caching scored analysis given word and bert
285+ # Create a key for caching scored analysis given word and bert
283286 # predictions
284287 key = (word , tuple (pred [feat ] for feat in self .features ))
285288 if key in self .ranking_cache :
@@ -363,10 +366,9 @@ def tag_sentences(self, sentences, use_analyzer=True):
363366 if not use_analyzer :
364367 return self ._predict_sentences (sentences )
365368
366- top = 0
367369 tagged_sentences = []
368370 for prediction in self .disambiguate_sentences (sentences ):
369- tagged_sentence = [a .analyses [top ].analysis for a in prediction ]
371+ tagged_sentence = [a .analyses [0 ].analysis for a in prediction ]
370372 tagged_sentences .append (tagged_sentence )
371373
372374 return tagged_sentences
@@ -390,8 +392,7 @@ def tag_sentence(self, sentence, use_analyzer=True):
390392 if not use_analyzer :
391393 return self ._predict_sentence (sentence )
392394
393- top = 0
394- return [a .analyses [top ].analysis for a in self .disambiguate (sentence )]
395+ return [a .analyses [0 ].analysis for a in self .disambiguate (sentence )]
395396
396397 def all_feats (self ):
397398 """Return a set of all features produced by this disambiguator.
@@ -430,8 +431,7 @@ def __init__(self, model_path, use_gpu=True):
430431 self .labels_map = self .model .config .id2label
431432 self .use_gpu = use_gpu
432433
433- @staticmethod
434- def labels ():
434+ def labels (self ):
435435 """Get the list of Morph labels returned by predictions.
436436
437437 Returns:
0 commit comments