@@ -1186,27 +1186,47 @@ def distance_between_spans(self, spans_nb: int, words_nb: int) -> torch.Tensor:
         return dist
 
     def closest_antecedents_indexs(
-        self, spans_nb: int, words_nb: int, antecedents_nb: int
+        self,
+        top_mentions_index: torch.Tensor,
+        spans_nb: int,
+        words_nb: int,
+        antecedents_nb: int,
     ):
11911195 """Compute the indexs of the k closest mentions
 
+        :param top_mentions_index: a tensor of shape ``(b, m)``
         :param spans_nb: number of spans in the sequence
         :param words_nb: number of words in the sequence
         :param antecedents_nb: number of antecedents to consider
-        :return: a tensor of shape ``(p, a)``
+        :return: a tensor of shape ``(b, p, a)``
         """
+        device = next(self.parameters()).device
+        b, _ = top_mentions_index.shape
+        p = spans_nb
+        a = antecedents_nb
+
         dist = self.distance_between_spans(spans_nb, words_nb)
-        assert dist.shape == (spans_nb, spans_nb)
+        assert dist.shape == (p, p)
 
         # when the distance between a span and a possible antecedent
-        # is 0 or negative, it means the possible antecedents is after
-        # the span. Therefore, it can't be an antecedents. We set
-        # those distances to Inf for torch.topk usage just after
+        # is 0 or negative, it means the possible antecedent is after
+        # the span. Therefore, it can't be an antecedent. We set those
+        # distances to Inf for torch.topk usage just after
         dist[dist <= 0] = float("Inf")
 
+        # discard pruned non-top mentions using the same technique as
+        # above
+        all_indices = torch.tile(torch.arange(spans_nb), (b, 1)).to(device)
+        pruned_mask = ~torch.isin(all_indices, top_mentions_index)
+        assert pruned_mask.shape == (b, p)
+        dist = torch.tile(dist, (b, 1, 1))
+        dist[pruned_mask, :] = float("Inf")  # remove pruned lines
+        dist.swapaxes(1, 2)[pruned_mask, :] = float("Inf")  # remove pruned cols
+        assert dist.shape == (b, p, p)
+
         # top-k closest antecedents
         _, close_indexs = torch.topk(-dist, antecedents_nb)
-        assert close_indexs.shape == (spans_nb, antecedents_nb)
+        assert close_indexs.shape == (b, p, a)
 
         return close_indexs
 
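To make the new masking step concrete, here is a minimal standalone sketch of the same mask-then-topk technique on toy values (the sizes b = 1, spans_nb = 4, antecedents_nb = 2 and the top_mentions_index values are illustrative, not taken from the patch):

import torch

b, spans_nb, antecedents_nb = 1, 4, 2
top_mentions_index = torch.tensor([[0, 2, 3]])  # hypothetical kept mentions

# dist[i][j] = i - j: positive iff span j comes before span i, so
# non-positive entries cannot be antecedents and are set to Inf
pos = torch.arange(spans_nb)
dist = (pos.unsqueeze(1) - pos.unsqueeze(0)).float()
dist[dist <= 0] = float("Inf")

# mask out rows and columns of pruned mentions, as in the patch
all_indices = torch.tile(torch.arange(spans_nb), (b, 1))
pruned_mask = ~torch.isin(all_indices, top_mentions_index)
dist = torch.tile(dist, (b, 1, 1))
dist[pruned_mask, :] = float("Inf")
dist.swapaxes(1, 2)[pruned_mask, :] = float("Inf")

# negating dist turns "k smallest distances" into torch.topk's "k largest"
_, close_indexs = torch.topk(-dist, antecedents_nb)
print(close_indexs.shape)  # torch.Size([1, 4, 2]), i.e. (b, p, a)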
@@ -1455,9 +1475,8 @@ def forward(
         # antecedents for each span
         antecedents_nb = a = min(self.config.antecedents_nb, spans_nb)
         antecedents_index = self.closest_antecedents_indexs(
-            spans_nb, words_nb, antecedents_nb
+            top_mentions_index, spans_nb, words_nb, antecedents_nb
         )
-        antecedents_index = torch.tile(antecedents_index, (batch_size, 1, 1))
         assert antecedents_index.shape == (b, p, a)
 
         # -- mention compatibility scores computation --
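Since the batch dimension is now produced inside closest_antecedents_indexs, the per-batch torch.tile in forward() becomes redundant, which this hunk removes. Assuming a model instance exposing this method, the updated call site behaves roughly as in this sketch (all tensor values and sizes are hypothetical):

# hypothetical inputs: batch of 2 sequences, 4 candidate spans each
top_mentions_index = torch.tensor([[0, 2, 3], [1, 2, 3]])  # (b, m)
antecedents_index = model.closest_antecedents_indexs(
    top_mentions_index, spans_nb=4, words_nb=10, antecedents_nb=2
)
assert antecedents_index.shape == (2, 4, 2)  # (b, p, a), no tiling needed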