Skip to content

Commit ee03bb1

Browse files
committed
Increase tolerance in retrieval transformer test and random seed (#1007)
1 parent 5d0e090 commit ee03bb1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/unit/tf/transformers/test_block.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
import tensorflow as tf
6+
from tensorflow.keras.utils import set_random_seed
67
from transformers import BertConfig
78

89
import merlin.models.tf as mm
@@ -27,6 +28,7 @@ def test_import():
2728

2829
@pytest.mark.parametrize("run_eagerly", [True])
2930
def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly):
31+
set_random_seed(42)
3032

3133
sequence_testing_data.schema = sequence_testing_data.schema.select_by_tag(
3234
Tags.SEQUENCE
@@ -78,7 +80,7 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly):
7880
assert list(item_embeddings.shape) == [51997, d_model]
7981
predicitons_2 = np.dot(query_embeddings, item_embeddings.T)
8082

81-
np.testing.assert_allclose(predictions, predicitons_2, atol=1e-7)
83+
np.testing.assert_allclose(predictions, predicitons_2, atol=1e-6)
8284

8385

8486
def test_transformer_encoder():

0 commit comments

Comments
 (0)