Skip to content

Commit fb51ac0

Browse files
committed
coref_labels now correctly ignores mentions longer than max_span_size
1 parent 58bca82 commit fb51ac0

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

tests/test_bertcoref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def bert_tokenizer() -> BertTokenizerFast:
1313
# we suppress the `function_scoped_fixture` health check since we want
1414
# to execute the `bert_tokenizer` fixture only once.
1515
@settings(deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture])
16-
@given(doc=coref_docs(min_size=5, max_size=10))
16+
@given(doc=coref_docs(min_size=5, max_size=10, max_span_size=4))
1717
def test_doc_is_reconstructed(
1818
doc: CoreferenceDocument, bert_tokenizer: BertTokenizerFast
1919
):

tibert/bertcoref.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,23 @@ def coref_labels(self, max_span_size: int) -> List[List[int]]:
9494

9595
# spans in a coref chain : mark all antecedents
9696
for chain in self.coref_chains:
97+
# mentions can be longer than max_span_size. We filter
98+
# these mentions so that they do not appear in the labels.
99+
chain = [m for m in chain if len(m.tokens) <= max_span_size]
100+
if len(chain) == 0:
101+
continue
97102
for mention in chain:
98-
try:
99-
mention_idx = spans_idx[(mention.start_idx, mention.end_idx)]
100-
for other_mention in chain:
101-
if other_mention == mention:
102-
continue
103-
other_mention_idx = spans_idx[
104-
(other_mention.start_idx, other_mention.end_idx)
105-
]
106-
labels[mention_idx][other_mention_idx] = 1
107-
# ValueError happens if the mention does not exist in
108-
# spans_idx. This is possible since the mention can be
109-
# larger than max_span_size
110-
except ValueError:
111-
continue
103+
mention_idx = spans_idx[(mention.start_idx, mention.end_idx)]
104+
for other_mention in chain:
105+
if other_mention == mention:
106+
continue
107+
key = (other_mention.start_idx, other_mention.end_idx)
108+
if not key in spans_idx:
109+
continue
110+
other_mention_idx = spans_idx[
111+
(other_mention.start_idx, other_mention.end_idx)
112+
]
113+
labels[mention_idx][other_mention_idx] = 1
112114

113115
# spans without preceding mentions : mark preceding mention to
114116
# be the null span

0 commit comments

Comments
 (0)