Skip to content

Commit d383052

Browse files
committed
Add test_timeseries_model.py
1 parent 342b808 commit d383052

File tree

2 files changed

+197
-1
lines changed

2 files changed

+197
-1
lines changed

src/pownet/stochastic/timeseries_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def _predict(self) -> pd.Series:
160160
pass
161161

162162
@abstractmethod
163-
def _get_synthetic(self, exog_vars: list[str], seed: int) -> pd.Series:
163+
def _get_synthetic(self, exog_data: pd.DataFrame, seed: int) -> pd.Series:
164164
pass
165165

166166
@abstractmethod
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""test_timeseries_model.py"""
2+
3+
import unittest
4+
import pandas as pd
5+
from pownet.stochastic import timeseries_model
6+
7+
8+
class ConcreteTimeSeriesModel(timeseries_model.TimeSeriesModel):
9+
"""A concrete implementation for testing the abstract base class."""
10+
11+
def __init__(self):
12+
super().__init__()
13+
self._predictions = pd.Series(dtype=float)
14+
self._residuals = pd.Series(dtype=float)
15+
self._monthly_models = {}
16+
17+
@property
18+
def monthly_models(self) -> dict:
19+
return self._monthly_models
20+
21+
@property
22+
def predictions(self) -> pd.Series:
23+
return self._predictions
24+
25+
@property
26+
def pred_residuals(self) -> pd.Series:
27+
return self._residuals
28+
29+
def _fit(
30+
self,
31+
target_column: str,
32+
arima_order: tuple[int, int, int],
33+
seasonal_order: tuple[int, int, int, int],
34+
exog_vars: list[str],
35+
) -> None:
36+
# Minimal implementation for testing purposes
37+
# In a real scenario, this would fit some model
38+
self._monthly_models[1] = "dummy_model_for_month_1"
39+
pass
40+
41+
def _predict(self) -> pd.Series:
42+
# Minimal implementation
43+
if not self.data.empty:
44+
return pd.Series(
45+
[1.0] * len(self.data), index=self.data.index, name="predictions"
46+
)
47+
return pd.Series(dtype=float)
48+
49+
def _get_synthetic(
50+
self, exog_data: pd.DataFrame = None, seed: int = None
51+
) -> pd.Series:
52+
# For this dummy, we'll keep the logic based on self.data, as it doesn't
53+
# actually use the exogenous variables for its dummy output.
54+
if not self.data.empty:
55+
return pd.Series(
56+
[0.5] * len(self.data), index=self.data.index, name="synthetic"
57+
)
58+
return pd.Series(dtype=float)
59+
60+
def _find_best_model(
61+
self,
62+
target_column: str,
63+
exog_vars: list[str],
64+
month_to_use: int,
65+
seed: int,
66+
suppress_warnings: bool,
67+
) -> tuple[tuple[int, int, int], tuple[int, int, int, int]]:
68+
# Minimal implementation
69+
return ((1, 0, 0), (0, 0, 0, 0))
70+
71+
72+
class TestTimeSeriesModel(unittest.TestCase):
73+
74+
def setUp(self):
75+
self.model = ConcreteTimeSeriesModel()
76+
self.sample_data = pd.DataFrame(
77+
{
78+
"datetime": pd.to_datetime(
79+
[
80+
"2023-01-01 00:00:00",
81+
"2023-01-01 01:00:00",
82+
"2023-01-01 02:00:00", # Changed this line
83+
"2023-01-01 03:00:00", # Added an extra point for more data
84+
]
85+
),
86+
"value": [10, 12, 15, 11], # Adjusted values
87+
"exog1": [1, 2, 3, 4], # Adjusted exog
88+
}
89+
)
90+
self.target_column = "value"
91+
92+
def test_initialization(self):
93+
self.assertFalse(self.model._is_fitted)
94+
self.assertFalse(self.model._is_loaded)
95+
self.assertTrue(self.model.data.empty)
96+
self.assertEqual(self.model.months, [])
97+
self.assertIsNone(self.model.exog_vars)
98+
99+
def test_load_data_success(self):
100+
self.model.load_data(self.sample_data.copy())
101+
self.assertTrue(self.model._is_loaded)
102+
self.assertFalse(self.model.data.empty)
103+
self.assertIn(pd.Timestamp("2023-01-01 00:00:00"), self.model.data.index)
104+
self.assertEqual(self.model.data.index.freqstr, "h")
105+
self.assertEqual(self.model.months, [1]) # Sorted
106+
self.assertIsInstance(self.model.data.index, pd.DatetimeIndex)
107+
108+
def test_load_data_missing_datetime_column(self):
109+
bad_data = pd.DataFrame({"val": [1, 2]})
110+
with self.assertRaisesRegex(ValueError, "Data should have columns 'datetime'"):
111+
self.model.load_data(bad_data)
112+
113+
def test_fit_success(self):
114+
self.model.load_data(self.sample_data.copy())
115+
self.model.fit(
116+
target_column=self.target_column, arima_order=(1, 0, 0), exog_vars=["exog1"]
117+
)
118+
self.assertTrue(self.model._is_fitted)
119+
self.assertEqual(self.model.exog_vars, ["exog1"])
120+
# You might also check if the dummy _fit method was "called" (e.g., by checking its side effects)
121+
self.assertIn(1, self.model.monthly_models) # Based on dummy _fit
122+
123+
def test_predict_not_fitted(self):
124+
with self.assertRaisesRegex(
125+
ValueError, "Model must be fitted before making predictions."
126+
):
127+
self.model.predict()
128+
129+
def test_predict_success(self):
130+
self.model.load_data(self.sample_data.copy())
131+
self.model.fit(target_column=self.target_column, arima_order=(1, 0, 0))
132+
predictions = self.model.predict()
133+
self.assertIsInstance(predictions, pd.Series)
134+
self.assertEqual(
135+
len(predictions), len(self.model.data)
136+
) # Based on dummy _predict
137+
138+
def test_get_synthetic_not_fitted(self):
139+
self.model.load_data(self.sample_data.copy())
140+
with self.assertRaisesRegex(
141+
ValueError, "Model must be fitted before creating synthetic data."
142+
):
143+
self.model.get_synthetic()
144+
145+
def test_get_synthetic_exog_vars_mismatch(self):
146+
self.model.load_data(self.sample_data.copy())
147+
self.model.fit(
148+
target_column=self.target_column,
149+
arima_order=(1, 0, 0),
150+
exog_vars=["exog_missing"],
151+
)
152+
exog_df = pd.DataFrame(
153+
{"exog_other": [1, 1, 1, 1]}, index=self.model.data.index
154+
)
155+
with self.assertRaisesRegex(
156+
ValueError, "Exogenous variables should be in the data."
157+
):
158+
self.model.get_synthetic(exog_data=exog_df)
159+
160+
def test_get_synthetic_exog_index_mismatch(self):
161+
self.model.load_data(self.sample_data.copy())
162+
self.model.fit(
163+
target_column=self.target_column, arima_order=(1, 0, 0), exog_vars=["exog1"]
164+
)
165+
wrong_index = pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"])
166+
exog_df = pd.DataFrame(
167+
{"exog1": [1, 1, 1]}, index=wrong_index
168+
) # Different index
169+
with self.assertRaisesRegex(
170+
ValueError,
171+
"Exogenous data should have the same index as the time series data.",
172+
):
173+
self.model.get_synthetic(exog_data=exog_df)
174+
175+
def test_get_synthetic_success(self):
176+
self.model.load_data(self.sample_data.copy())
177+
self.model.fit(target_column=self.target_column, arima_order=(1, 0, 0))
178+
synthetic_data = self.model.get_synthetic()
179+
self.assertIsInstance(synthetic_data, pd.Series)
180+
self.assertEqual(len(synthetic_data), len(self.model.data))
181+
182+
def test_find_best_model_not_loaded(self):
183+
with self.assertRaisesRegex(ValueError, "Data must be loaded first."):
184+
self.model.find_best_model(target_column=self.target_column)
185+
186+
def test_find_best_model_success(self):
187+
self.model.load_data(self.sample_data.copy())
188+
order, seasonal_order = self.model.find_best_model(
189+
target_column=self.target_column
190+
)
191+
self.assertEqual(order, (1, 0, 0)) # From dummy implementation
192+
self.assertEqual(seasonal_order, (0, 0, 0, 0)) # From dummy implementation
193+
194+
195+
if __name__ == "__main__":
196+
unittest.main()

0 commit comments

Comments
 (0)