Skip to content

Commit f4e8e07

Browse files
committed
Improved analysis sorting using pos-lex log prob.
1 parent 249288e commit f4e8e07

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

camel_tools/disambig/bert/unfactored.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,23 @@ def _scored_analyses(self, word_dd, prediction):
476476
tie_breaker=self._tie_breaker,
477477
features=self._features), a)
478478
for a in analyses]
479-
scored.sort(key=lambda s: (-s[0], s[1]['diac']))
479+
# scored.sort(key=lambda s: (-s[0], s[1]['diac']))
480480

481481
max_score = max(s[0] for s in scored)
482482

483-
if max_score != 0:
484-
scored_analyses = [ScoredAnalysis(s[0] / max_score, s[1])
485-
for s in scored]
486-
else:
487-
# If the max score is 0, do not divide
488-
scored_analyses = [ScoredAnalysis(0, s[1]) for s in scored]
483+
if max_score == 0:
484+
max_score = 1
485+
486+
scored_analyses = [
487+
ScoredAnalysis(
488+
s / max_score, # score
489+
a, # analysis
490+
a['diac'], # diac
491+
a.get('pos_lex_logprob', -99), # pos_lex_logprob
492+
a.get('lex_logprob', -99), # lex_logprob
493+
) for s, a in scored]
494+
495+
scored_analyses.sort()
489496

490497
return scored_analyses[:self._top]
491498

camel_tools/disambig/common.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,49 @@
3232
from collections import namedtuple
3333

3434

35-
class ScoredAnalysis(namedtuple('ScoredAnalysis', ['score', 'analysis'])):
35+
class ScoredAnalysis(namedtuple('ScoredAnalysis',
36+
[
37+
'score',
38+
'analysis',
39+
'diac',
40+
'pos_lex_logprob',
41+
'lex_logprob'
42+
])):
3643
"""A named tuple containing an analysis and its score.
3744
3845
Attributes:
39-
score (:obj:`float`): The score of a given analysis.
46+
score (:obj:`float`): The overall score of the analysis.
4047
4148
analysis (:obj:`dict`): The analysis dictionary.
42-
See :doc:`/reference/camel_morphology_features` for more information on
43-
features and their values.
49+
See :doc:`/reference/camel_morphology_features` for more
50+
information on features and their values.
51+
52+
diac (:obj:`str`): The diactrized form of the associated analysis.
53+
Used for tie-breaking equally scored analyses.
54+
55+
pos_lex_log_prob (:obj:`float`): The log (base 10) of the probability
56+
of the associated pos-lex pair values.
57+
Used for tie-breaking equally scored analyses.
58+
59+
lex_log_prob (:obj:`float`): The log (base 10) of the probability of
60+
the associated lex value.
61+
Used for tie-breaking equally scored analyses.
4462
"""
4563

64+
def __lt__(self, other):
65+
if self.score > other.score:
66+
return True
67+
elif self.score == other.score:
68+
if self.pos_lex_logprob > other.pos_lex_logprob:
69+
return True
70+
elif self.pos_lex_logprob == other.pos_lex_logprob:
71+
if self.lex_logprob > other.lex_logprob:
72+
return True
73+
elif self.lex_logprob == other.lex_logprob:
74+
return self.diac < other.diac
75+
76+
return False
77+
4678

4779
class DisambiguatedWord(namedtuple('DisambiguatedWord', ['word', 'analyses'])):
4880
"""A named tuple containing a word and a sorted list (from high to low

camel_tools/disambig/mle.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,17 @@ def _scored_analyses(self, word_dd):
204204

205205
scored_analyses = [ScoredAnalysis(p / max_prob, a)
206206
for a, p in zip(analyses, probabilities)]
207-
scored_analyses.sort(key=lambda w: (-w.score,
208-
len(w.analysis['bw']),
209-
w.analysis['diac']))
207+
208+
scored_analyses = [
209+
ScoredAnalysis(
210+
p / max_prob, # score
211+
a, # analysis
212+
a['diac'], # diac
213+
a.get('pos_lex_logprob', -99), # pos_lex_logprob
214+
a.get('lex_logprob', -99), # lex_logprob
215+
) for a, p in zip(analyses, probabilities)]
216+
217+
scored_analyses.sort()
210218

211219
return scored_analyses[0:self._top]
212220

0 commit comments

Comments
 (0)