Skip to content

Commit bd6ebaf

Browse files
committed
ACTUALLY fix the span pruning issue
1 parent 7ee36ee commit bd6ebaf

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

tibert/bertcoref.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,14 +1088,14 @@ def mention_compatibility_score(
10881088
def pruned_mentions_indexs(
10891089
self, mention_scores: torch.Tensor, words_nb: int, top_mentions_nb: int
10901090
) -> torch.Tensor:
1091-
"""Prune mentions, keeping only the k non-overlapping best of them
1091+
"""Prune mentions, keeping only the k non-crossing best of them
10921092
10931093
The algorithm works as follows :
10941094
10951095
1. Sort mentions by individual scores
10961096
2. Accept mention in order, from best to worst score, until k of
10971097
them are accepted. A mention can only be accepted if no
1098-
previously accepted span is overlapping with it.
1098+
previously accepted span is crossing with it.
10991099
11001100
See section 5 of the E2ECoref paper and the C++ kernel in the
11011101
E2ECoref repository.
@@ -1116,10 +1116,12 @@ def pruned_mentions_indexs(
11161116

11171117
spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size)
11181118

1119-
def spans_are_overlapping(
1120-
span1: Tuple[int, int], span2: Tuple[int, int]
1121-
) -> bool:
1122-
return not (span1[1] <= span2[0] or span1[0] >= span2[1])
1119+
def spans_are_crossing(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool:
1120+
start1, end1 = (span1[0], span1[1] - 1)
1121+
start2, end2 = (span2[0], span2[1] - 1)
1122+
return (start1 < start2 and start2 <= end1 and end1 < end2) or (
1123+
start2 < start1 and start1 <= end2 and end2 < end1
1124+
)
11231125

11241126
_, sorted_indexs = torch.sort(mention_scores, 1, descending=True)
11251127
# TODO: what if we can't have top_mentions_nb mentions ??
@@ -1134,7 +1136,7 @@ def spans_are_overlapping(
11341136
span_index = int(sorted_indexs[b_i][s_j].item())
11351137
if not any(
11361138
[
1137-
spans_are_overlapping(
1139+
spans_are_crossing(
11381140
spans_idx[span_index], spans_idx[mention_idx]
11391141
)
11401142
for mention_idx in mention_indexs[-1]

0 commit comments

Comments
 (0)