Skip to content

Commit 98b00ea

Browse files
authored
Fix handling of small docs in coref (#28)
* Fix handling of small docs in coref

Docs with one or zero tokens fail in the coref component. This doesn't have a fix yet, just a failing test. (There is also a test for the span resolver, which does not fail.)

* Add example short doc to tests

It might be better to include this optionally? On the other hand, since it should just be ignored in training, having it always there is more thorough.

* Skip short docs

There can be no coref prediction for docs with one token (or no tokens). Attempting to treat docs like that normally causes a mess with size inference, so instead they're skipped. In training, this just involves skipping the docs in the update step. This is simple due to the fake batching structure, since the batch doesn't have to be maintained. In inference, this just involves short-circuiting to an empty prediction.

* Clean up retokenization test

The retokenization test is hard-coded to the training example because it manually merges some tokens, to make sure that the prediction and merge line up. It would probably be better to separate out the training data from the general example here, but for now narrowing the training data works.
1 parent 5cd4731 commit 98b00ea

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

spacy_experimental/coref/coref_component.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
145145
"""
146146
out = []
147147
for doc in docs:
148+
if len(doc) < 2:
149+
# no coref in docs with 0 or 1 token
150+
out.append([])
151+
continue
152+
148153
scores, idxs = self.model.predict([doc])
149154
# idxs is a list of mentions (start / end idxs)
150155
# each item in scores includes scores and a mapping from scores to mentions
@@ -232,6 +237,9 @@ def update(
232237
predicted docs in coref training.
233238
"""
234239
)
240+
if len(eg.predicted) < 2:
241+
# no prediction possible for docs of length 0 or 1
242+
continue
235243
preds, backprop = self.model.begin_update([eg.predicted])
236244
score_matrix, mention_idx = preds
237245
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)

spacy_experimental/coref/tests/test_coref.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def generate_train_data(prefix=DEFAULT_CLUSTER_PREFIX):
3737
}
3838
},
3939
),
40+
(
41+
# example short doc
42+
"ok",
43+
{"spans": {}}
44+
)
4045
]
4146
# fmt: on
4247
return data
@@ -83,11 +88,12 @@ def test_initialized(nlp):
8388

8489

8590
def test_initialized_short(nlp):
91+
# test that short or empty docs don't fail
8692
nlp.add_pipe("experimental_coref")
8793
nlp.initialize()
8894
assert nlp.pipe_names == ["experimental_coref"]
89-
text = "Hi there"
90-
doc = nlp(text)
95+
doc = nlp("Hi")
96+
doc = nlp("")
9197

9298

9399
def test_coref_serialization(nlp):
@@ -148,7 +154,8 @@ def test_overfitting_IO(nlp, train_data):
148154

149155
def test_tokenization_mismatch(nlp, train_data):
150156
train_examples = []
151-
for text, annot in train_data:
157+
# this is testing a specific test example, so just get the first doc
158+
for text, annot in train_data[0:1]:
152159
eg = Example.from_dict(nlp.make_doc(text), annot)
153160
ref = eg.reference
154161
char_spans = {}

spacy_experimental/coref/tests/test_span_resolver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ def test_not_initialized(nlp):
7979
with pytest.raises(ValueError, match="E109"):
8080
nlp(text)
8181

82+
def test_initialized_short(nlp):
83+
# docs with one or no tokens should not fail
84+
nlp.add_pipe("experimental_span_resolver")
85+
nlp.initialize()
86+
assert nlp.pipe_names == ["experimental_span_resolver"]
87+
nlp("hi")
88+
nlp("")
8289

8390
def test_span_resolver_serialization(nlp):
8491
# Test that the span resolver component can be serialized

0 commit comments

Comments
 (0)