Skip to content

Commit 2af31a8

Browse files
Bugfix textcat reproducibility on GPU (#6411)
* add seed argument to ParametricAttention layer
* bump thinc to 7.4.3
* set thinc version range

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
1 parent cdca44a commit 2af31a8

File tree

2 files changed: 8 additions (+8) and 6 deletions (-6)

2 files changed: 8 additions (+8) and 6 deletions (-6)

spacy/_ml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
646 646   SpacyVectors
647 647   >> flatten_add_lengths
648 648   >> with_getitem(0, Affine(width, pretrained_dims))
649     - >> ParametricAttention(width)
    649 + >> ParametricAttention(width, seed=100)
650 650   >> Pooling(sum_pool)
651 651   >> Residual(ReLu(width, width)) ** 2
652 652   >> zero_init(Affine(nr_class, width, drop_factor=0.0))
@@ -688,7 +688,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
688 688   cnn_model = (
689 689   tok2vec
690 690   >> flatten_add_lengths
691     - >> ParametricAttention(width)
    691 + >> ParametricAttention(width, seed=99)
692 692   >> Pooling(sum_pool)
693 693   >> Residual(zero_init(Maxout(width, width)))
694 694   >> zero_init(Affine(nr_class, width, drop_factor=0.0))

spacy/tests/regression/test_issue6177.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ def test_issue6177():
11 11   # NOTE: no need to transform this code to v3 when 'master' is merged into 'develop'.
12 12   # A similar test exists already for v3: test_issue5551
13 13   # This is just a backport
14    -
15 14   results = []
16 15   for i in range(3):
17 16   fix_random_seed(0)
@@ -24,12 +23,15 @@ def test_issue6177():
24 23   nlp.add_pipe(textcat)
25 24   for label in set(example[1]["cats"]):
26 25   textcat.add_label(label)
27    - nlp.begin_training()
   26 + # Train
   27 + optimizer = nlp.begin_training()
   28 + text, annots = example
   29 + nlp.update([text], [annots], sgd=optimizer)
28 30   # Store the result of each iteration
29    - result = textcat.model.predict([nlp.make_doc(example[0])])
   31 + result = textcat.model.predict([nlp.make_doc(text)])
30 32   results.append(list(result[0]))
31 33
32 34   # All results should be the same because of the fixed seed
33 35   assert len(results) == 3
34 36   assert results[0] == results[1]
35    - assert results[0] == results[2]
   37 + assert results[0] == results[2]

0 commit comments

Comments
 (0)