Skip to content

Commit 25e9350

Browse files
committed
try to fix file exists error
1 parent 10ccf58 commit 25e9350

File tree

3 files changed

+10
-18
lines changed

3 files changed

+10
-18
lines changed

tests/modules/decision/test_threshold.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from autointent.exceptions import MismatchNumClassesError
55
from autointent.modules import ThresholdDecision
6-
from tests.conftest import setup_environment
76

87

98
@pytest.mark.parametrize(
@@ -43,17 +42,16 @@ def test_fails_on_wrong_n_classes_fit(multiclass_fit_data):
4342

4443

4544
@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
46-
def test_dump_load(fit_fixture, request):
45+
def test_dump_load(fit_fixture, request, tmp_path):
4746
fit_data = request.getfixturevalue(fit_fixture)
4847
predictor = ThresholdDecision(thresh=0.3)
4948
predictor.fit(*fit_data)
5049
predictions = predictor.predict(fit_data[0])
5150

52-
path = setup_environment() / "threshold_module"
53-
predictor.dump(path)
51+
predictor.dump(tmp_path)
5452
del predictor
5553

56-
predictor = ThresholdDecision.load(path)
54+
predictor = ThresholdDecision.load(tmp_path)
5755

5856
assert hasattr(predictor, "thresh")
5957
assert predictor.thresh is not None

tests/modules/decision/test_tunable.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from autointent.exceptions import MismatchNumClassesError
55
from autointent.modules import TunableDecision
6-
from tests.conftest import setup_environment
76

87

98
@pytest.mark.parametrize(
@@ -41,17 +40,16 @@ def test_fails_on_wrong_n_classes_predict(multiclass_fit_data):
4140

4241

4342
@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
44-
def test_dump_load(fit_fixture, request):
43+
def test_dump_load(fit_fixture, request, tmp_path):
4544
fit_data = request.getfixturevalue(fit_fixture)
4645
predictor = TunableDecision()
4746
predictor.fit(*fit_data)
4847
predictions = predictor.predict(fit_data[0])
4948

50-
path = setup_environment() / "tunable_module"
51-
predictor.dump(path)
49+
predictor.dump(tmp_path)
5250
del predictor
5351

54-
predictor = TunableDecision.load(path)
52+
predictor = TunableDecision.load(tmp_path)
5553
assert hasattr(predictor, "thresh")
5654
assert predictor.thresh is not None
5755
assert isinstance(predictor.thresh, np.ndarray)
Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import shutil
1+
from pathlib import Path
22

33
from autointent.modules.embedding import RetrievalAimedEmbedding
4-
from tests.conftest import setup_environment
54

65

76
def test_get_assets_returns_correct_artifact():
@@ -10,20 +9,17 @@ def test_get_assets_returns_correct_artifact():
109
assert artifact.config.model_name == "sergeyzh/rubert-tiny-turbo"
1110

1211

13-
def test_dump_and_load_preserves_model_state():
14-
project_dir = setup_environment()
12+
def test_dump_and_load_preserves_model_state(tmp_path: Path):
1513
module = RetrievalAimedEmbedding(k=5, embedder_config="sergeyzh/rubert-tiny-turbo")
1614

1715
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
1816
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
1917
module.fit(utterances, labels)
2018
predictions = module.predict(utterances)
2119

22-
module.dump(project_dir)
20+
module.dump(tmp_path)
2321
del module
2422

25-
loaded_module = RetrievalAimedEmbedding.load(project_dir)
23+
loaded_module = RetrievalAimedEmbedding.load(tmp_path)
2624
predictions_loaded = loaded_module.predict(utterances)
2725
assert predictions == predictions_loaded
28-
29-
shutil.rmtree(project_dir)

0 commit comments

Comments
 (0)