Skip to content

Commit 3bc1845

Browse files
committed
fix: fixed test
1 parent 6e1b920 commit 3bc1845

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/modules/scoring/test_description.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import tempfile
2+
13
import numpy as np
24
import pytest
35

@@ -91,3 +93,17 @@ def test_description_scorer_cross_encoder(dataset, expected_prediction, multilab
9193
assert metadata is None
9294

9395
scorer.clear_cache()
96+
97+
with tempfile.TemporaryDirectory() as temp_dir:
98+
scorer.dump(temp_dir)
99+
100+
new_scorer = DescriptionScorer(
101+
cross_encoder_config="cross-encoder/ms-marco-MiniLM-L-6-v2", encoder_type="cross", temperature=0.3
102+
)
103+
new_scorer.load(temp_dir)
104+
105+
loaded_predictions = new_scorer.predict(test_utterances)
106+
107+
np.testing.assert_almost_equal(predictions, loaded_predictions, decimal=5)
108+
109+
new_scorer.clear_cache()

0 commit comments

Comments
 (0)