@@ -94,21 +94,23 @@ def coref_labels(self, max_span_size: int) -> List[List[int]]:
9494
9595 # spans in a coref chain : mark all antecedents
9696 for chain in self .coref_chains :
97+ # mentions can be longer than max_span_size. We filter
98+ # these mentions so that they do not appear in the labels.
99+ chain = [m for m in chain if len (m .tokens ) <= max_span_size ]
100+ if len (chain ) == 0 :
101+ continue
97102 for mention in chain :
98- try :
99- mention_idx = spans_idx [(mention .start_idx , mention .end_idx )]
100- for other_mention in chain :
101- if other_mention == mention :
102- continue
103- other_mention_idx = spans_idx [
104- (other_mention .start_idx , other_mention .end_idx )
105- ]
106- labels [mention_idx ][other_mention_idx ] = 1
107- # ValueError happens if the mention does not exist in
108- # spans_idx. This is possible since the mention can be
109- # larger than max_span_size
110- except ValueError :
111- continue
103+ mention_idx = spans_idx [(mention .start_idx , mention .end_idx )]
104+ for other_mention in chain :
105+ if other_mention == mention :
106+ continue
107+ key = (other_mention .start_idx , other_mention .end_idx )
108+ if not key in spans_idx :
109+ continue
110+ other_mention_idx = spans_idx [
111+ (other_mention .start_idx , other_mention .end_idx )
112+ ]
113+ labels [mention_idx ][other_mention_idx ] = 1
112114
113115 # spans without preceding mentions : mark preceding mention to
114116 # be the null span
0 commit comments