Skip to content

Commit 794d18e

Browse files
committed
Fixed usage of updated ScoredAnalysis.
1 parent 7a0753a commit 794d18e

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

camel_tools/disambig/bert/unfactored.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,12 @@ def _scored_analyses(self, word_dd, prediction):
468468
if len(analyses) == 0:
469469
# If the word is not found in the analyzer,
470470
# return the predictions from BERT
471-
return [ScoredAnalysis(0, bert_analysis)]
471+
return [ScoredAnalysis(0, # score
472+
bert_analysis, # analysis
473+
bert_analysis['diac'], # diac
474+
-99, # pos_lex_logprob
475+
-99, # lex_logprob
476+
)]
472477

473478
scored = [(self._scorer(a,
474479
bert_analysis,

camel_tools/disambig/mle.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,17 @@ def _scored_analyses(self, word_dd):
188188

189189
max_score = max([s[0] for s in scored])
190190

191-
scored_analyses = [ScoredAnalysis(s[0] / max_score, s[1])
192-
for s in scored]
191+
if max_score == 0:
192+
max_score = 1
193+
194+
scored_analyses = [
195+
ScoredAnalysis(
196+
s / max_score, # score
197+
a, # analysis
198+
a['diac'], # diac
199+
a.get('pos_lex_logprob', -99), # pos_lex_logprob
200+
a.get('lex_logprob', -99), # lex_logprob
201+
) for s, a in scored]
193202

194203
return scored_analyses[0:self._top]
195204

@@ -202,9 +211,6 @@ def _scored_analyses(self, word_dd):
202211
probabilities = [10 ** _get_pos_lex_logprob(a) for a in analyses]
203212
max_prob = max(probabilities)
204213

205-
scored_analyses = [ScoredAnalysis(p / max_prob, a)
206-
for a, p in zip(analyses, probabilities)]
207-
208214
scored_analyses = [
209215
ScoredAnalysis(
210216
p / max_prob, # score

0 commit comments

Comments
 (0)