Skip to content

Commit b64d8be

Browse files
Add unit tests for save and load utils functions
1 parent b4590bf commit b64d8be

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Test persistence module."""
2+
3+
import pickle
4+
5+
import pytest # type: ignore
6+
7+
from frouros.callbacks import HistoryConceptDrift, PermutationTestDistanceBased
8+
from frouros.callbacks.base import BaseCallback
9+
from frouros.detectors.base import BaseDetector
10+
from frouros.detectors.concept_drift import DDM, DDMConfig
11+
from frouros.detectors.data_drift import MMD
12+
from frouros.utils import load, save
13+
14+
15+
@pytest.fixture(
16+
scope="module",
17+
params=[
18+
DDM(
19+
config=DDMConfig(),
20+
),
21+
MMD(),
22+
],
23+
)
24+
def detector(
25+
request: pytest.FixtureRequest,
26+
) -> BaseDetector:
27+
"""Fixture for detector.
28+
29+
:param request: Request
30+
:type request: pytest.FixtureRequest
31+
:return: Detector
32+
:rtype: BaseDetector
33+
"""
34+
return request.param
35+
36+
37+
@pytest.fixture(
38+
scope="module",
39+
params=[
40+
HistoryConceptDrift(),
41+
PermutationTestDistanceBased(
42+
num_permutations=2,
43+
),
44+
],
45+
)
46+
def callback(
47+
request: pytest.FixtureRequest,
48+
) -> BaseCallback:
49+
"""Fixture for callback.
50+
51+
:param request: Request
52+
:type request: pytest.FixtureRequest
53+
:return: Callback
54+
:rtype: BaseCallback
55+
"""
56+
return request.param
57+
58+
59+
def test_save_load_with_valid_detector(
60+
detector: BaseDetector,
61+
) -> None:
62+
"""Test save and load with valid detector.
63+
64+
:param detector: Detector
65+
:type detector: BaseDetector
66+
"""
67+
filename = "/tmp/detector.pkl"
68+
save(detector, filename)
69+
loaded_detector = load(filename)
70+
assert isinstance(loaded_detector, detector.__class__)
71+
72+
73+
def test_save_load_with_valid_callback(
74+
callback: BaseCallback,
75+
) -> None:
76+
"""Test save and load with valid callback.
77+
78+
:param callback: Callback
79+
:type callback: BaseCallback
80+
"""
81+
filename = "/tmp/callback.pkl"
82+
save(callback, filename)
83+
loaded_callback = load(filename)
84+
assert isinstance(loaded_callback, BaseCallback)
85+
86+
87+
def test_save_with_invalid_object() -> None:
88+
"""Test save with invalid object.
89+
90+
:raises TypeError: Type error exception
91+
"""
92+
invalid_object = "invalid"
93+
filename = "/tmp/invalid.pkl"
94+
with pytest.raises(TypeError):
95+
save(invalid_object, filename)
96+
97+
98+
def test_save_with_invalid_protocol(
99+
detector: BaseDetector,
100+
) -> None:
101+
"""Test save with invalid protocol.
102+
103+
:param detector: Detector
104+
:type detector: BaseDetector
105+
:raises ValueError: Value error exception
106+
"""
107+
filename = "/tmp/detector.pkl"
108+
invalid_protocol = pickle.HIGHEST_PROTOCOL + 1
109+
with pytest.raises(ValueError):
110+
save(detector, filename, invalid_protocol)
111+
112+
113+
def test_load_with_non_existent_file() -> None:
114+
"""Test load with non-existent file.
115+
116+
:raises FileNotFoundError: File not found error exception
117+
"""
118+
filename = "/tmp/non_existent.pkl"
119+
with pytest.raises(FileNotFoundError):
120+
load(filename)

0 commit comments

Comments
 (0)