Skip to content

Commit 33286c9

Browse files
authored
Fix issue with resolving final token in SpanResolver (#27)
* Fix issue with resolving final token in SpanResolver

  The SpanResolver seems unable to include the final token in a Doc in
  output spans. It will even produce empty spans instead of doing so. This
  makes changes so that within the model, span end indices are treated as
  inclusive, and converts them back to exclusive when annotating docs.
  This has been tested to work, though an automated test should be added.

* Modify tests so the last token is in a mention

  Running the modified tests without the changes from the previous commit,
  they fail. This demonstrates and clarifies the bug.

* Add / rearrange comments
1 parent 98b00ea commit 33286c9

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

spacy_experimental/coref/pytorch_span_resolver_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def forward(
5353
Returns:
5454
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
5555
"""
56+
5657
# If we don't receive heads, return empty
5758
device = heads_ids.device
5859
if heads_ids.nelement() == 0:

spacy_experimental/coref/span_resolver_component.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
176176
"""
177177
for doc, clusters in zip(docs, clusters_by_doc):
178178
for ii, cluster in enumerate(clusters, 1):
179-
spans = [doc[int(mm[0]) : int(mm[1])] for mm in cluster]
179+
# Note the +1, since model end indices are inclusive
180+
spans = [doc[int(mm[0]) : int(mm[1]) + 1] for mm in cluster]
180181
doc.spans[f"{self.output_prefix}_{ii}"] = spans
181182

182183
def update(
@@ -274,9 +275,12 @@ def get_loss(
274275

275276
# NOTE This is doing fake batching, and should always get a list of one example
276277
assert len(list(examples)) == 1, "Only fake batching is supported."
277-
# starts and ends are gold starts and ends (Ints1d)
278-
# span_scores is a Floats3d. What are the axes? mention x token x start/end
278+
279+
# NOTE Within this component, end token indices are *inclusive*. This
280+
# is different than normal Python/spaCy representations, but has the
281+
# advantage that the set of possible start and end indices is the same.
279282
for eg in examples:
283+
# starts and ends are gold starts and ends (Ints1d)
280284
starts = []
281285
ends = []
282286
keeps = []
@@ -296,11 +300,12 @@ def get_loss(
296300
)
297301
continue
298302
starts.append(span.start)
299-
ends.append(span.end)
303+
ends.append(span.end - 1)
300304
keeps.append(sidx - 1)
301305

302306
starts_xp = self.model.ops.xp.asarray(starts)
303307
ends_xp = self.model.ops.xp.asarray(ends)
308+
# span_scores is a Floats3d. Axes: mention x token x start/end
304309
start_scores = span_scores[:, :, 0][keeps]
305310
end_scores = span_scores[:, :, 1][keeps]
306311

spacy_experimental/coref/tests/test_span_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def generate_train_data(
2222
# fmt: off
2323
data = [
2424
(
25-
"John Smith picked up the red ball and he threw it away.",
25+
"John Smith picked up the red ball and he threw it",
2626
{
2727
"spans": {
2828
f"{output_prefix}_1": [

0 commit comments

Comments (0)