Skip to content

Commit bccb36e

Browse files
committed
closest_antecedents_indexs now only selects among non-pruned mentions
1 parent bd6ebaf commit bccb36e

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

tibert/bertcoref.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,27 +1186,47 @@ def distance_between_spans(self, spans_nb: int, words_nb: int) -> torch.Tensor:
11861186
return dist
11871187

11881188
def closest_antecedents_indexs(
1189-
self, spans_nb: int, words_nb: int, antecedents_nb: int
1189+
self,
1190+
top_mentions_index: torch.Tensor,
1191+
spans_nb: int,
1192+
words_nb: int,
1193+
antecedents_nb: int,
11901194
):
11911195
"""Compute the indexs of the k closest mentions
11921196
1197+
:param top_mentions_index: a tensor of shape ``(b, m)``
11931198
:param spans_nb: number of spans in the sequence
11941199
:param words_nb: number of words in the sequence
11951200
:param antecedents_nb: number of antecedents to consider
1196-
:return: a tensor of shape ``(p, a)``
1201+
:return: a tensor of shape ``(b, p, a)``
11971202
"""
1203+
device = next(self.parameters()).device
1204+
b, _ = top_mentions_index.shape
1205+
p = spans_nb
1206+
a = antecedents_nb
1207+
11981208
dist = self.distance_between_spans(spans_nb, words_nb)
1199-
assert dist.shape == (spans_nb, spans_nb)
1209+
assert dist.shape == (p, p)
12001210

12011211
# when the distance between a span and a possible antecedent
1202-
# is 0 or negative, it means the possible antecedents is after
1203-
# the span. Therefore, it can't be an antecedents. We set
1204-
# those distances to Inf for torch.topk usage just after
1212+
# is 0 or negative, it means the possible antecedent is after
1213+
# the span. Therefore, it can't be an antecedent. We set those
1214+
# distances to Inf for torch.topk usage just after
12051215
dist[dist <= 0] = float("Inf")
12061216

1217+
# discard pruned non-top mentions using the same technique as
1218+
# above
1219+
all_indices = torch.tile(torch.arange(spans_nb), (b, 1)).to(device)
1220+
pruned_mask = ~torch.isin(all_indices, top_mentions_index)
1221+
assert pruned_mask.shape == (b, p)
1222+
dist = torch.tile(dist, (b, 1, 1))
1223+
dist[pruned_mask, :] = float("Inf") # remove pruned lines
1224+
dist.swapaxes(1, 2)[pruned_mask, :] = float("Inf") # remove pruned cols
1225+
assert dist.shape == (b, p, p)
1226+
12071227
# top-k closest antecedents
12081228
_, close_indexs = torch.topk(-dist, antecedents_nb)
1209-
assert close_indexs.shape == (spans_nb, antecedents_nb)
1229+
assert close_indexs.shape == (b, p, a)
12101230

12111231
return close_indexs
12121232

@@ -1455,9 +1475,8 @@ def forward(
14551475
# antecedents for each spans
14561476
antecedents_nb = a = min(self.config.antecedents_nb, spans_nb)
14571477
antecedents_index = self.closest_antecedents_indexs(
1458-
spans_nb, words_nb, antecedents_nb
1478+
top_mentions_index, spans_nb, words_nb, antecedents_nb
14591479
)
1460-
antecedents_index = torch.tile(antecedents_index, (batch_size, 1, 1))
14611480
assert antecedents_index.shape == (b, p, a)
14621481

14631482
# -- mention compatibility scores computation --

0 commit comments

Comments
 (0)