@@ -833,8 +833,8 @@ class BertCoreferenceResolutionOutput:
833833 # (batch_size, top_mentions_nb)
834834 top_mentions_index : torch .Tensor
835835
836- # (batch_size, top_mentions_nb )
837- top_mentions_scores : torch .Tensor
836+ # (batch_size, spans_nb )
837+ mentions_scores : torch .Tensor
838838
839839 # (batch_size, top_mentions_nb, antecedents_nb)
840840 top_antecedents_index : torch .Tensor
@@ -873,38 +873,40 @@ def coreference_documents(
873873
874874 G = nx .Graph ()
875875 for m_j in range (top_mentions_nb ):
876- span_idx = int (self .top_mentions_index [b_i ][m_j ].item ())
877- span_coords = spans_idx [span_idx ]
878-
879- top_antecedent_idx = int (antecedents_idx [b_i ][m_j ].item ())
876+ span_i = int (self .top_mentions_index [b_i ][m_j ].item ())
877+ span_coords = spans_idx [span_i ]
880878
881- mention_score = float (self .top_mentions_scores [b_i ][m_j ].item ())
879+ mention_score = float (self .mentions_scores [b_i ][span_i ].item ())
882880 span_mention = Mention (
883881 tokens [b_i ][span_coords [0 ] : span_coords [1 ]],
884882 span_coords [0 ],
885883 span_coords [1 ],
886884 mention_score = mention_score ,
887885 )
888886
887+ # index of the best antecedent in self.top_antecedent_index
888+ top_antecedent_idx = int (antecedents_idx [b_i ][m_j ].item ())
889+
889890 # the antecedent is the dummy mention : maybe we have
890891 # a one-mention chain ?
891892 if top_antecedent_idx == antecedents_nb - 1 :
892- if float (self .top_mentions_scores [b_i ][m_j ].item ()) > 0.0 :
893+ if float (self .mentions_scores [b_i ][span_i ].item ()) > 0.0 :
893894 G .add_node (span_mention )
894895 continue
895896
896- antecedent_idx = int (
897+ antecedent_span_i = int (
897898 self .top_antecedents_index [b_i ][m_j ][top_antecedent_idx ].item ()
898899 )
900+ antecedent_coords = spans_idx [antecedent_span_i ]
899901
900- antecedent_coords = spans_idx [ antecedent_idx ]
901-
902- mention_score = float ( self . top_mentions_scores [ b_i ][ m_j ]. item () )
902+ antecedent_mention_score = float (
903+ self . mentions_scores [ b_i ][ antecedent_span_i ]. item ()
904+ )
903905 antecedent_mention = Mention (
904906 tokens [b_i ][antecedent_coords [0 ] : antecedent_coords [1 ]],
905907 antecedent_coords [0 ],
906908 antecedent_coords [1 ],
907- mention_score = mention_score ,
909+ mention_score = antecedent_mention_score ,
908910 )
909911
910912 G .add_node (antecedent_mention )
@@ -1510,7 +1512,7 @@ def forward(
15101512 return BertCoreferenceResolutionOutput (
15111513 final_scores ,
15121514 top_mentions_index ,
1513- top_mention_scores ,
1515+ mention_scores ,
15141516 top_antecedents_index ,
15151517 self .config .max_span_size ,
15161518 loss = loss ,
0 commit comments