11import shutil
2- from pathlib import Path
32from unittest .mock import MagicMock
43
54import numpy as np
98
109
1110def test_get_assets_returns_correct_artifact_for_logreg ():
12- db_dir , dump_dir , logs_dir = setup_environment ()
13- module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" , db_dir = db_dir )
11+ module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
1412 artifact = module .get_assets ()
1513 assert artifact .embedder_name == "sergeyzh/rubert-tiny-turbo"
1614
1715
1816def test_fit_trains_model ():
19- db_dir , dump_dir , logs_dir = setup_environment ()
20- module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" , db_dir = db_dir )
17+ module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
2118
2219 utterances = ["hello" , "goodbye" , "hi" , "bye" , "bye" , "hello" , "welcome" , "hi123" , "hiii" , "bye-bye" , "bye!" ]
2320 labels = [0 , 1 , 0 , 1 , 1 , 0 , 0 , 0 , 0 , 1 , 1 ]
@@ -29,8 +26,7 @@ def test_fit_trains_model():
2926
3027
3128def test_score_evaluates_model ():
32- db_dir , dump_dir , logs_dir = setup_environment ()
33- module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" , db_dir = db_dir )
29+ module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
3430
3531 utterances = ["hello" , "goodbye" , "hi" , "bye" , "bye" , "hello" , "welcome" , "hi123" , "hiii" , "bye-bye" , "bye!" ]
3632 labels = [0 , 1 , 0 , 1 , 1 , 0 , 0 , 0 , 0 , 1 , 1 ]
@@ -50,24 +46,22 @@ def mock_metric_fn(true_labels, predicted_labels):
5046
5147
5248def test_dump_and_load_preserves_model_state ():
53- db_dir , dump_dir , logs_dir = setup_environment ()
54- module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" , db_dir = db_dir )
49+ dump_dir , _ = setup_environment ()
50+ module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
5551
5652 utterances = ["hello" , "goodbye" , "hi" , "bye" , "bye" , "hello" , "welcome" , "hi123" , "hiii" , "bye-bye" , "bye!" ]
5753 labels = [0 , 1 , 0 , 1 , 1 , 0 , 0 , 0 , 0 , 1 , 1 ]
5854 module .fit (utterances , labels )
5955
60- dump_path = Path (dump_dir )
61- dump_path .mkdir (parents = True , exist_ok = True )
62- module .dump (str (dump_path ))
56+ module .dump (dump_dir )
6357
64- loaded_module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" , db_dir = db_dir )
65- loaded_module .load (str ( dump_path ) )
58+ loaded_module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
59+ loaded_module .load (dump_dir )
6660 epsilon = 1e-6
6761
6862 assert np .allclose (loaded_module .classifier .coef_ , module .classifier .coef_ , atol = epsilon )
6963 assert np .allclose (loaded_module .classifier .intercept_ , module .classifier .intercept_ , atol = epsilon )
7064 assert np .array_equal (np .array (loaded_module .label_encoder .classes_ ), np .array (module .label_encoder .classes_ ))
7165 assert loaded_module .embedder_name == module .embedder_name
7266
73- shutil .rmtree (dump_path )
67+ shutil .rmtree (dump_dir )
0 commit comments