@@ -1088,14 +1088,14 @@ def mention_compatibility_score(
10881088 def pruned_mentions_indexs (
10891089 self , mention_scores : torch .Tensor , words_nb : int , top_mentions_nb : int
10901090 ) -> torch .Tensor :
1091- """Prune mentions, keeping only the k non-overlapping best of them
1091+ """Prune mentions, keeping only the k non-crossing best of them
10921092
10931093 The algorithm works as follows :
10941094
10951095 1. Sort mentions by individual scores
10961096 2. Accept mention in order, from best to worst score, until k of
10971097 them are accepted. A mention can only be accepted if no
1098- previously accepted span is overlapping with it.
1098+ previously accepted span is crossing with it.
10991099
11001100 See section 5 of the E2ECoref paper and the C++ kernel in the
11011101 E2ECoref repository.
@@ -1116,10 +1116,12 @@ def pruned_mentions_indexs(
11161116
11171117 spans_idx = spans_indexs (list (range (words_nb )), self .config .max_span_size )
11181118
1119- def spans_are_overlapping (
1120- span1 : Tuple [int , int ], span2 : Tuple [int , int ]
1121- ) -> bool :
1122- return not (span1 [1 ] <= span2 [0 ] or span1 [0 ] >= span2 [1 ])
1119+ def spans_are_crossing (span1 : Tuple [int , int ], span2 : Tuple [int , int ]) -> bool :
1120+ start1 , end1 = (span1 [0 ], span1 [1 ] - 1 )
1121+ start2 , end2 = (span2 [0 ], span2 [1 ] - 1 )
1122+ return (start1 < start2 and start2 <= end1 and end1 < end2 ) or (
1123+ start2 < start1 and start1 <= end2 and end2 < end1
1124+ )
11231125
11241126 _ , sorted_indexs = torch .sort (mention_scores , 1 , descending = True )
11251127 # TODO: what if we can't have top_mentions_nb mentions ??
@@ -1134,7 +1136,7 @@ def spans_are_overlapping(
11341136 span_index = int (sorted_indexs [b_i ][s_j ].item ())
11351137 if not any (
11361138 [
1137- spans_are_overlapping (
1139+ spans_are_crossing (
11381140 spans_idx [span_index ], spans_idx [mention_idx ]
11391141 )
11401142 for mention_idx in mention_indexs [- 1 ]
0 commit comments