11import numpy as np
22import pytensor
33import pytest
4+ import statsmodels .api as sm
45from numpy .testing import assert_allclose
56from pytensor .graph .basic import explicit_graph_inputs
67from scipy import linalg
8+ from statespace .utils .constants import LONG_MATRIX_NAMES
79
810from pymc_experimental .statespace .models .ETS import BayesianETS
9- from tests .statespace .utilities .test_helpers import (
10- load_nile_test_data ,
11- simulate_from_numpy_model ,
12- )
11+ from tests .statespace .utilities .shared_fixtures import rng
12+ from tests .statespace .utilities .test_helpers import load_nile_test_data
1313
1414
1515@pytest .fixture (scope = "session" )
@@ -43,78 +43,66 @@ def tests_invalid_order_raises():
4343 BayesianETS (order = ("A" , "Ad" , "A" ))
4444
4545
46+ orders = (
47+ ("A" , "N" , "N" ),
48+ ("A" , "A" , "N" ),
49+ ("A" , "Ad" , "N" ),
50+ ("A" , "N" , "A" ),
51+ ("A" , "A" , "A" ),
52+ ("A" , "Ad" , "A" ),
53+ )
54+ order_names = (
55+ "Basic" ,
56+ "Trend" ,
57+ "Damped Trend" ,
58+ "Seasonal" ,
59+ "Trend and Seasonal" ,
60+ "Trend, Damped Trend, Seasonal" ,
61+ )
62+
63+ order_expected_flags = (
64+ {"trend" : False , "damped_trend" : False , "seasonal" : False },
65+ {"trend" : True , "damped_trend" : False , "seasonal" : False },
66+ {"trend" : True , "damped_trend" : True , "seasonal" : False },
67+ {"trend" : False , "damped_trend" : False , "seasonal" : True },
68+ {"trend" : True , "damped_trend" : False , "seasonal" : True },
69+ {"trend" : True , "damped_trend" : True , "seasonal" : True },
70+ )
71+
72+ order_params = (
73+ ["alpha" , "initial_level" ],
74+ ["alpha" , "initial_level" , "beta" , "initial_trend" ],
75+ ["alpha" , "initial_level" , "beta" , "initial_trend" , "phi" ],
76+ ["alpha" , "initial_level" , "gamma" , "initial_seasonal" ],
77+ ["alpha" , "initial_level" , "beta" , "initial_trend" , "gamma" , "initial_seasonal" ],
78+ ["alpha" , "initial_level" , "beta" , "initial_trend" , "gamma" , "initial_seasonal" , "phi" ],
79+ )
80+
81+
4682@pytest .mark .parametrize (
47- "order, expected_flags" ,
48- [
49- (("A" , "N" , "N" ), {"trend" : False , "damped_trend" : False , "seasonal" : False }),
50- (("A" , "A" , "N" ), {"trend" : True , "damped_trend" : False , "seasonal" : False }),
51- (("A" , "Ad" , "N" ), {"trend" : True , "damped_trend" : True , "seasonal" : False }),
52- (("A" , "N" , "A" ), {"trend" : False , "damped_trend" : False , "seasonal" : True }),
53- (("A" , "A" , "A" ), {"trend" : True , "damped_trend" : False , "seasonal" : True }),
54- (("A" , "Ad" , "A" ), {"trend" : True , "damped_trend" : True , "seasonal" : True }),
55- ],
56- ids = [
57- "Basic" ,
58- "Trend" ,
59- "Damped Trend" ,
60- "Seasonal" ,
61- "Trend and Seasonal" ,
62- "Trend, Damped Trend, Seasonal" ,
63- ],
83+ "order, expected_flags" , zip (orders , order_expected_flags ), ids = order_names
6484)
6585def test_order_flags (order , expected_flags ):
6686 mod = BayesianETS (order = order , seasonal_periods = 4 )
6787 for key , value in expected_flags .items ():
6888 assert getattr (mod , key ) == value
6989
7090
71- @pytest .mark .parametrize (
72- "order, expected_params" ,
73- [
74- (("A" , "N" , "N" ), ["alpha" ]),
75- (("A" , "A" , "N" ), ["alpha" , "beta" ]),
76- (("A" , "Ad" , "N" ), ["alpha" , "beta" , "phi" ]),
77- (("A" , "N" , "A" ), ["alpha" , "gamma" ]),
78- (("A" , "A" , "A" ), ["alpha" , "beta" , "gamma" ]),
79- (("A" , "Ad" , "A" ), ["alpha" , "beta" , "gamma" , "phi" ]),
80- ],
81- ids = [
82- "Basic" ,
83- "Trend" ,
84- "Damped Trend" ,
85- "Seasonal" ,
86- "Trend and Seasonal" ,
87- "Trend, Damped Trend, Seasonal" ,
88- ],
89- )
91+ @pytest .mark .parametrize ("order, expected_params" , zip (orders , order_params ), ids = order_names )
9092def test_param_info (order : tuple [str , str , str ], expected_params ):
9193 mod = BayesianETS (order = order , seasonal_periods = 4 )
9294
93- all_expected_params = [* expected_params , "sigma_state" , "x0" , " P0" ]
95+ all_expected_params = [* expected_params , "sigma_state" , "P0" ]
9496 assert all (param in mod .param_names for param in all_expected_params )
9597 assert all (param in all_expected_params for param in mod .param_names )
96- assert all (mod .param_info [param ]["dims" ] is None for param in expected_params )
98+ assert all (
99+ mod .param_info [param ]["dims" ] is None
100+ for param in expected_params
101+ if "seasonal" not in param
102+ )
97103
98104
99- @pytest .mark .parametrize (
100- "order, expected_params" ,
101- [
102- (("A" , "N" , "N" ), ["alpha" ]),
103- (("A" , "A" , "N" ), ["alpha" , "beta" ]),
104- (("A" , "Ad" , "N" ), ["alpha" , "beta" , "phi" ]),
105- (("A" , "N" , "A" ), ["alpha" , "gamma" ]),
106- (("A" , "A" , "A" ), ["alpha" , "beta" , "gamma" ]),
107- (("A" , "Ad" , "A" ), ["alpha" , "beta" , "gamma" , "phi" ]),
108- ],
109- ids = [
110- "Basic" ,
111- "Trend" ,
112- "Damped Trend" ,
113- "Seasonal" ,
114- "Trend and Seasonal" ,
115- "Trend, Damped Trend, Seasonal" ,
116- ],
117- )
105+ @pytest .mark .parametrize ("order, expected_params" , zip (orders , order_params ), ids = order_names )
118106def test_statespace_matrices (order : tuple [str , str , str ], expected_params : list [str ]):
119107 seasonal_periods = np .random .randint (3 , 12 )
120108 mod = BayesianETS (order = order , seasonal_periods = seasonal_periods , measurement_error = True )
@@ -127,7 +115,9 @@ def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[
127115 "phi" : 0.95 ,
128116 "sigma_state" : 0.1 ,
129117 "sigma_obs" : 0.1 ,
130- "x0" : np .zeros (expected_states ),
118+ "initial_level" : 3.0 ,
119+ "initial_trend" : 1.0 ,
120+ "initial_seasonal" : np .ones (seasonal_periods ),
131121 "initial_state_cov" : np .eye (expected_states ),
132122 }
133123
@@ -161,42 +151,91 @@ def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[
161151 Z_val [0 , 0 ] = 1.0
162152 Z_val [0 , 1 ] = 1.0
163153
154+ x0_val = np .zeros ((expected_states ,))
155+ x0_val [1 ] = test_values ["initial_level" ]
156+
164157 if order [1 ] == "N" :
165158 T_val = np .array ([[0.0 , 0.0 ], [0.0 , 1.0 ]])
166159 else :
160+ x0_val [2 ] = test_values ["initial_trend" ]
167161 R_val [2 ] = test_values ["beta" ]
168162 T_val = np .array ([[0.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 1.0 ], [0.0 , 0.0 , 1.0 ]])
169- Z_val [0 , 2 ] = 1.0
170163
171164 if order [1 ] == "Ad" :
172165 T_val [1 :, - 1 ] *= test_values ["phi" ]
173166
174167 if order [2 ] == "A" :
175- R_val [3 ] = test_values ["gamma" ]
168+ x0_val [2 + int (order [1 ] != "N" ) :] = test_values ["initial_seasonal" ]
169+ R_val [2 + int (order [1 ] != "N" )] = test_values ["gamma" ]
176170 S = np .eye (seasonal_periods , k = - 1 )
177- S [0 , : ] = - 1
171+ S [0 , - 1 ] = 1.0
178172 Z_val [0 , 2 + int (order [1 ] != "N" )] = 1.0
179173 else :
180174 S = np .eye (0 )
181175
182176 T_val = linalg .block_diag (T_val , S )
183177
178+ assert_allclose (x0 , x0_val )
184179 assert_allclose (T , T_val )
185180 assert_allclose (R , R_val )
186181 assert_allclose (Z , Z_val )
187182
188183
189- def test_deterministic_simulation_matches_statsmodels ():
190- mod = BayesianETS (order = ("A" , "Ad" , "A" ), seasonal_periods = 4 , measurement_error = False )
184+ @pytest .mark .parametrize ("order, params" , zip (orders , order_params ), ids = order_names )
185+ def test_statespace_matches_statsmodels (rng , order : tuple [str , str , str ], params ):
186+ seasonal_periods = rng .integers (3 , 12 )
187+ data = rng .normal (size = (100 ,))
188+ mod = BayesianETS (order = order , seasonal_periods = seasonal_periods , measurement_error = False )
189+ sm_mod = sm .tsa .statespace .ExponentialSmoothing (
190+ data ,
191+ trend = mod .trend ,
192+ damped_trend = mod .damped_trend ,
193+ seasonal = seasonal_periods if mod .seasonal else None ,
194+ )
195+
196+ simplex_params = ["alpha" , "beta" , "gamma" ]
197+ test_values = dict (zip (simplex_params , rng .dirichlet (alpha = np .ones (3 ))))
198+ test_values ["phi" ] = rng .beta (1 , 1 )
199+
200+ test_values ["initial_level" ] = rng .normal ()
201+ test_values ["initial_trend" ] = rng .normal ()
202+ test_values ["initial_seasonal" ] = rng .normal (size = seasonal_periods )
203+ test_values ["initial_state_cov" ] = np .eye (mod .k_states )
204+ test_values ["sigma_state" ] = 1.0
205+
206+ sm_test_values = test_values .copy ()
207+ sm_test_values ["smoothing_level" ] = test_values ["alpha" ]
208+ sm_test_values ["smoothing_trend" ] = test_values ["beta" ]
209+ sm_test_values ["smoothing_seasonal" ] = test_values ["gamma" ]
210+ sm_test_values ["damping_trend" ] = test_values ["phi" ]
211+ sm_test_values ["initial_seasonal" ] = test_values ["initial_seasonal" ][0 ]
212+ for i in range (1 , seasonal_periods ):
213+ sm_test_values [f"initial_seasonal.L{ i } " ] = test_values ["initial_seasonal" ][i ]
214+
215+ x0 = np .r_ [
216+ 0 , * [test_values [name ] for name in ["initial_level" , "initial_trend" , "initial_seasonal" ]]
217+ ]
218+ mask = [True , True , order [1 ] != "N" , * (order [2 ] != "N" ,) * seasonal_periods ]
219+
220+ sm_mod .initialize_known (initial_state = x0 [mask ], initial_state_cov = np .eye (mod .k_states ))
221+ sm_mod .fit_constrained ({name : sm_test_values [name ] for name in sm_mod .param_names })
222+
223+ matrices = mod ._unpack_statespace_with_placeholders ()
224+ inputs = list (explicit_graph_inputs (matrices ))
225+ input_names = [x .name for x in inputs ]
191226
192- rng = np .random .default_rng ()
193- test_values = {
194- "alpha" : 0.7 ,
195- "beta" : 0.15 ,
196- "gamma" : 0.15 ,
197- "phi" : 0.95 ,
198- "sigma_state" : 0.0 ,
199- "x0" : rng .normal (size = (7 ,)),
200- "initial_state_cov" : np .eye (7 ),
201- }
202- hidden_states , observed = simulate_from_numpy_model (mod , rng , test_values )
227+ f_matrices = pytensor .function (inputs , matrices )
228+ test_values_subset = {name : test_values [name ] for name in input_names }
229+
230+ matrices = f_matrices (** test_values_subset )
231+ sm_matrices = [sm_mod .ssm [name ] for name in LONG_MATRIX_NAMES [2 :]]
232+
233+ for matrix , sm_matrix , name in zip (matrices [2 :], sm_matrices , LONG_MATRIX_NAMES [2 :]):
234+ if name == "selection" :
235+ # statsmodel selection matrix seems to be wrong? They set the first element of the selection matrix to
236+ # 1 - sum(alpha, beta, gamma), which doesn't match the equations presented in ffp3
237+ assert_allclose (matrix [1 :], sm_matrix [1 :], err_msg = f"{ name } does not match" )
238+ assert matrix [0 ] == 1.0
239+ assert sm_matrix [0 ] != 1.0
240+ else :
241+ assert_allclose (matrix , sm_matrix , err_msg = f"{ name } does not match" )
0 commit comments