@@ -158,7 +158,7 @@ def prepared_document(
         except ValueError:
             continue
         new_chain.append(
-            Mention(tokens[start_idx : end_idx + 1], start_idx, end_idx)
+            Mention(tokens[start_idx : end_idx + 1], start_idx, end_idx + 1)
         )
     if len(new_chain) > 0:
         new_chains.append(new_chain)
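Taken together, the hunks in this commit move `Mention.end_idx` from an inclusive to an exclusive convention, so that `tokens[start_idx : end_idx]` recovers the mention text directly. A minimal sketch of the two conventions, using a simplified stand-in for the library's `Mention` class and made-up tokens:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Mention:
    tokens: List[str]
    start_idx: int
    end_idx: int  # exclusive after this commit: tokens == doc[start_idx:end_idx]


doc = ["Jane", "met", "her", "brother"]

# old convention: end_idx is inclusive, so slicing needs end_idx + 1
old = Mention(doc[0:1], 0, 0)
# new convention: end_idx is exclusive, so the slice is doc[start_idx:end_idx]
new = Mention(doc[0:1], 0, 1)
assert new.tokens == doc[new.start_idx : new.end_idx]
```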
@@ -193,7 +193,7 @@ def from_wpieced_to_tokenized(
             new_end_idx = wp_to_token[mention.end_idx]
             new_chain.append(
                 Mention(
-                    tokens[new_start_idx : new_end_idx + 1],
+                    tokens[new_start_idx : new_end_idx],
                     new_start_idx,
                     new_end_idx,
                 )
@@ -227,7 +227,7 @@ def from_labels(
             # singleton cluster
             if mention_labels[i] == 1:
                 start_idx, end_idx = spans_idx[i]
-                mention_tokens = tokens[start_idx : end_idx + 1]
+                mention_tokens = tokens[start_idx : end_idx]
                 chains.append([Mention(mention_tokens, start_idx, end_idx)])
                 already_visited_mentions.append(i)

@@ -237,7 +237,7 @@ def from_labels(
                 continue

             start_idx, end_idx = spans_idx[i]
-            mention_tokens = tokens[start_idx : end_idx + 1]
+            mention_tokens = tokens[start_idx : end_idx]
             chain = [Mention(mention_tokens, start_idx, end_idx)]

             for j, label in enumerate(mlabels):
@@ -246,7 +246,7 @@ def from_labels(
                     continue

                 start_idx, end_idx = spans_idx[j]
-                mention_tokens = tokens[start_idx : end_idx + 1]
+                mention_tokens = tokens[start_idx : end_idx]
                 chain.append(Mention(mention_tokens, start_idx, end_idx))
                 already_visited_mentions.append(j)

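In `from_labels`, the `(start_idx, end_idx)` pairs in `spans_idx` are now read as right-exclusive, so the `+ 1` disappears from every slice. A hedged illustration, reusing the `spans_idx` name from the diff with made-up data:

```python
tokens = ["the", "red", "car", "stopped"]
spans_idx = [(0, 3), (3, 4)]  # exclusive ends: "the red car", "stopped"

for start_idx, end_idx in spans_idx:
    mention_tokens = tokens[start_idx:end_idx]
    print(mention_tokens)
# ['the', 'red', 'car']
# ['stopped']
```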
@@ -426,9 +426,9 @@ def from_conll2012_file(

         if mention_is_ending:
             mention_start_idx = open_mentions[chain_id].pop()
-            mention_end_idx = len(document_tokens) - 1
+            mention_end_idx = len(document_tokens)
             mention = Mention(
-                document_tokens[mention_start_idx : mention_end_idx + 1],
+                document_tokens[mention_start_idx : mention_end_idx],
                 mention_start_idx,
                 mention_end_idx,
             )
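In the CoNLL-2012 reader, a mention is closed right after its last token has been appended to `document_tokens`, so under the exclusive convention the end index is simply the current length of that list. A small sketch of this bookkeeping, with made-up tokens:

```python
document_tokens = ["Sarah", "lost", "her", "keys"]

# the mention "her keys" opened at index 2; its last token, "keys",
# was just appended, so the exclusive end is the current list length
mention_start_idx = 2
mention_end_idx = len(document_tokens)  # 4, exclusive

assert document_tokens[mention_start_idx:mention_end_idx] == ["her", "keys"]
```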
@@ -665,7 +665,7 @@ def coreference_documents(
                     # the antecedent is the dummy mention: maybe we have
                     # a one-mention chain?
                     if top_antecedent_idx == antecedents_nb - 1:
-                        if self.top_mentions_scores[b_i][m_j].item() > 0.0:
+                        if float(self.top_mentions_scores[b_i][m_j].item()) > 0.0:
                             G.add_node(span_mention)
                         continue
@@ -907,7 +907,7 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor:

         # distance between a span and its antecedent is defined to be
         # the span start index minus the antecedent span end index
-        dist = start_end_idx_combinations[:, 0] - start_end_idx_combinations[:, 1]
+        dist = start_end_idx_combinations[:, 0] - start_end_idx_combinations[:, 1] + 1
         assert dist.shape == (p * p,)
         dist = dist.reshape(spans_nb, spans_nb)