Skip to content

Commit d6729c1

Browse files
committed
fix wrong mention_score being assigned to antecedent mentions
1 parent 374b123 commit d6729c1

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

tibert/bertcoref.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,8 @@ class BertCoreferenceResolutionOutput:
833833
# (batch_size, top_mentions_nb)
834834
top_mentions_index: torch.Tensor
835835

836-
# (batch_size, top_mentions_nb)
837-
top_mentions_scores: torch.Tensor
836+
# (batch_size, spans_nb)
837+
mentions_scores: torch.Tensor
838838

839839
# (batch_size, top_mentions_nb, antecedents_nb)
840840
top_antecedents_index: torch.Tensor
@@ -873,38 +873,40 @@ def coreference_documents(
873873

874874
G = nx.Graph()
875875
for m_j in range(top_mentions_nb):
876-
span_idx = int(self.top_mentions_index[b_i][m_j].item())
877-
span_coords = spans_idx[span_idx]
878-
879-
top_antecedent_idx = int(antecedents_idx[b_i][m_j].item())
876+
span_i = int(self.top_mentions_index[b_i][m_j].item())
877+
span_coords = spans_idx[span_i]
880878

881-
mention_score = float(self.top_mentions_scores[b_i][m_j].item())
879+
mention_score = float(self.mentions_scores[b_i][span_i].item())
882880
span_mention = Mention(
883881
tokens[b_i][span_coords[0] : span_coords[1]],
884882
span_coords[0],
885883
span_coords[1],
886884
mention_score=mention_score,
887885
)
888886

887+
# index of the best antecedent in self.top_antecedent_index
888+
top_antecedent_idx = int(antecedents_idx[b_i][m_j].item())
889+
889890
# the antecedent is the dummy mention : maybe we have
890891
# a one-mention chain ?
891892
if top_antecedent_idx == antecedents_nb - 1:
892-
if float(self.top_mentions_scores[b_i][m_j].item()) > 0.0:
893+
if float(self.mentions_scores[b_i][span_i].item()) > 0.0:
893894
G.add_node(span_mention)
894895
continue
895896

896-
antecedent_idx = int(
897+
antecedent_span_i = int(
897898
self.top_antecedents_index[b_i][m_j][top_antecedent_idx].item()
898899
)
900+
antecedent_coords = spans_idx[antecedent_span_i]
899901

900-
antecedent_coords = spans_idx[antecedent_idx]
901-
902-
mention_score = float(self.top_mentions_scores[b_i][m_j].item())
902+
antecedent_mention_score = float(
903+
self.mentions_scores[b_i][antecedent_span_i].item()
904+
)
903905
antecedent_mention = Mention(
904906
tokens[b_i][antecedent_coords[0] : antecedent_coords[1]],
905907
antecedent_coords[0],
906908
antecedent_coords[1],
907-
mention_score=mention_score,
909+
mention_score=antecedent_mention_score,
908910
)
909911

910912
G.add_node(antecedent_mention)
@@ -1510,7 +1512,7 @@ def forward(
15101512
return BertCoreferenceResolutionOutput(
15111513
final_scores,
15121514
top_mentions_index,
1513-
top_mention_scores,
1515+
mention_scores,
15141516
top_antecedents_index,
15151517
self.config.max_span_size,
15161518
loss=loss,

0 commit comments

Comments
 (0)