Skip to content

Commit 372a908

Browse files
authored
Fix spancat-singlelabel score (#12469)
* debug argmax sort and add span scores * add missing tests for spanscores
1 parent dba4e7b commit 372a908

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

spacy/pipeline/spancat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ def _make_span_group_singlelabel(
726726
if not allow_overlap:
727727
# Get the probabilities
728728
sort_idx = (argmax_scores.squeeze() * -1).argsort()
729+
argmax_scores = argmax_scores[sort_idx]
729730
predicted = predicted[sort_idx]
730731
indices = indices[sort_idx]
731732
keeps = keeps[sort_idx]
@@ -748,4 +749,5 @@ def _make_span_group_singlelabel(
748749
attrs_scores.append(argmax_scores[i])
749750
spans.append(Span(doc, start, end, label=self.labels[label]))
750751

752+
spans.attrs["scores"] = numpy.array(attrs_scores)
751753
return spans

spacy/tests/pipeline/test_spancat.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,19 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
190190
spangroup = spancat._make_span_group_singlelabel(
191191
doc, indices, scores, allow_overlap
192192
)
193-
assert len(spangroup) == nr_results
194193
if threshold > 0.4:
195194
if allow_overlap:
196195
assert spangroup[0].text == "London"
197196
assert spangroup[0].label_ == "City"
197+
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
198198
assert spangroup[1].text == "Greater London"
199199
assert spangroup[1].label_ == "GreatCity"
200-
200+
assert spangroup.attrs["scores"][1] == 0.9
201+
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
201202
else:
202203
assert spangroup[0].text == "Greater London"
203204
assert spangroup[0].label_ == "GreatCity"
205+
assert spangroup.attrs["scores"][0] == 0.9
204206
else:
205207
if allow_overlap:
206208
assert spangroup[0].text == "Greater"
@@ -256,22 +258,32 @@ def test_make_spangroup_negative_label():
256258
assert len(spangroup_single) == 2
257259
assert spangroup_single[0].text == "Greater"
258260
assert spangroup_single[0].label_ == "City"
261+
assert_almost_equal(0.4, spangroup_single.attrs["scores"][0], 5)
259262
assert spangroup_single[1].text == "Greater London"
260263
assert spangroup_single[1].label_ == "GreatCity"
264+
assert spangroup_single.attrs["scores"][1] == 0.9
265+
assert_almost_equal(0.9, spangroup_single.attrs["scores"][1], 5)
261266

262267
assert len(spangroup_multi) == 6
263268
assert spangroup_multi[0].text == "Greater"
264269
assert spangroup_multi[0].label_ == "City"
270+
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][0], 5)
265271
assert spangroup_multi[1].text == "Greater"
266272
assert spangroup_multi[1].label_ == "Person"
273+
assert_almost_equal(0.3, spangroup_multi.attrs["scores"][1], 5)
267274
assert spangroup_multi[2].text == "London"
268275
assert spangroup_multi[2].label_ == "City"
276+
assert_almost_equal(0.6, spangroup_multi.attrs["scores"][2], 5)
269277
assert spangroup_multi[3].text == "London"
270278
assert spangroup_multi[3].label_ == "GreatCity"
279+
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][3], 5)
271280
assert spangroup_multi[4].text == "Greater London"
272281
assert spangroup_multi[4].label_ == "Thing"
282+
assert spangroup_multi[4].text == "Greater London"
283+
assert_almost_equal(0.8, spangroup_multi.attrs["scores"][4], 5)
273284
assert spangroup_multi[5].text == "Greater London"
274285
assert spangroup_multi[5].label_ == "GreatCity"
286+
assert_almost_equal(0.9, spangroup_multi.attrs["scores"][5], 5)
275287

276288

277289
def test_ngram_suggester(en_tokenizer):

0 commit comments

Comments
 (0)