Skip to content

Commit ac0f4f7

Browse files
committed
Fix style
1 parent bb57113 commit ac0f4f7

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

camel_tools/disambig/bert/__init__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,21 @@
3333
from torch.utils.data import DataLoader
3434
from transformers import BertForTokenClassification, BertTokenizer
3535

36+
from camel_tools.data import CATALOGUE
3637
from camel_tools.morphology.database import MorphologyDB
3738
from camel_tools.morphology.analyzer import Analyzer
38-
39-
from camel_tools.data import CATALOGUE
4039
from camel_tools.disambig.common import Disambiguator, DisambiguatedWord
4140
from camel_tools.disambig.common import ScoredAnalysis
42-
4341
from camel_tools.disambig.bert.bert_morph_dataset import MorphDataset
4442
from camel_tools.disambig.score_function import score_analysis_uniform
4543
from camel_tools.disambig.score_function import FEATURE_SET_MAP
4644

4745

46+
__all__ = [
47+
'BERTUnfactoredDisambiguator',
48+
]
49+
50+
4851
_SCORING_FUNCTION_MAP = {
4952
'uniform': score_analysis_uniform
5053
}
@@ -219,8 +222,8 @@ def _predict_sentences(self, sentences):
219222
for feat in pred.split('__'):
220223
f, v = feat.split(':')
221224
d[f] = v
222-
d['lex'] = word # copy the word when analyzer is not used
223-
d['diac'] = word # copy the word when analyzer is not used
225+
d['lex'] = word # Copy the word when analyzer is not used
226+
d['diac'] = word # Copy the word when analyzer is not used
224227
parsed_prediction.append(d)
225228
parsed_predictions.append(parsed_prediction)
226229

@@ -246,8 +249,8 @@ def _predict_sentence(self, sentence):
246249
for feat in pred.split('__'):
247250
f, v = feat.split(':')
248251
d[f] = v
249-
d['lex'] = word # copy the word when analyzer is not used
250-
d['diac'] = word # copy the word when analyzer is not used
252+
d['lex'] = word # Copy the word when analyzer is not used
253+
d['diac'] = word # Copy the word when analyzer is not used
251254
parsed_predictions.append(d)
252255

253256
return parsed_predictions
@@ -257,7 +260,7 @@ def _scored_analyses(self, word_dd, prediction):
257260
analyses = self._analyzer.analyze(word_dd)
258261

259262
if len(analyses) == 0:
260-
# if the word is not found in the analyzer,
263+
# If the word is not found in the analyzer,
261264
# return the predictions from BERT
262265
return [ScoredAnalysis(0, bert_analysis)]
263266

@@ -273,13 +276,13 @@ def _scored_analyses(self, word_dd, prediction):
273276
scored_analyses = [ScoredAnalysis(s[0] / max_score, s[1])
274277
for s in scored]
275278
else:
276-
# if the max score is 0, do not divide
279+
# If the max score is 0, do not divide
277280
scored_analyses = [ScoredAnalysis(s[0], s[1]) for s in scored]
278281

279282
return scored_analyses[:self._top]
280283

281284
def _disambiguate_word(self, word, pred):
282-
# create a key for caching scored analysis given word and bert
285+
# Create a key for caching scored analysis given word and bert
283286
# predictions
284287
key = (word, tuple(pred[feat] for feat in self.features))
285288
if key in self.ranking_cache:
@@ -363,10 +366,9 @@ def tag_sentences(self, sentences, use_analyzer=True):
363366
if not use_analyzer:
364367
return self._predict_sentences(sentences)
365368

366-
top = 0
367369
tagged_sentences = []
368370
for prediction in self.disambiguate_sentences(sentences):
369-
tagged_sentence = [a.analyses[top].analysis for a in prediction]
371+
tagged_sentence = [a.analyses[0].analysis for a in prediction]
370372
tagged_sentences.append(tagged_sentence)
371373

372374
return tagged_sentences
@@ -390,8 +392,7 @@ def tag_sentence(self, sentence, use_analyzer=True):
390392
if not use_analyzer:
391393
return self._predict_sentence(sentence)
392394

393-
top = 0
394-
return [a.analyses[top].analysis for a in self.disambiguate(sentence)]
395+
return [a.analyses[0].analysis for a in self.disambiguate(sentence)]
395396

396397
def all_feats(self):
397398
"""Return a set of all features produced by this disambiguator.
@@ -430,8 +431,7 @@ def __init__(self, model_path, use_gpu=True):
430431
self.labels_map = self.model.config.id2label
431432
self.use_gpu = use_gpu
432433

433-
@staticmethod
434-
def labels():
434+
def labels(self):
435435
"""Get the list of Morph labels returned by predictions.
436436
437437
Returns:

0 commit comments

Comments
 (0)