@@ -79,7 +79,7 @@ def test_dfm_parameter_and_matrix_match(data, k_factors, factor_order, error_ord
79
79
pm .Deterministic (
80
80
"error_ar" , pt .constant (np .full ((k_endog , error_order ), 0.5 ), dtype = floatX )
81
81
)
82
- pm .Deterministic ("factor_sigma" , pt .constant (np .full ((k_factors ,), 0.5 ), dtype = floatX ))
82
+ # pm.Deterministic("factor_sigma", pt.constant(np.full((k_factors,), 0.5), dtype=floatX))
83
83
pm .Deterministic ("error_sigma" , pt .constant (np .full ((k_endog ,), 0.5 ), dtype = floatX ))
84
84
pm .Deterministic ("sigma_obs" , pt .constant (np .full ((k_endog ,), 0.5 ), dtype = floatX ))
85
85
@@ -96,3 +96,124 @@ def test_dfm_parameter_and_matrix_match(data, k_factors, factor_order, error_ord
96
96
atol = 1e-12 ,
97
97
err_msg = f"Matrix mismatch: { mat_name } (k_factors={ k_factors } , factor_order={ factor_order } , error_order={ error_order } )" ,
98
98
)
99
+
100
+
101
+ @pytest .mark .parametrize ("k_factors" , [1 , 2 ])
102
+ @pytest .mark .parametrize ("factor_order" , [0 , 1 , 2 ])
103
+ @pytest .mark .parametrize ("error_order" , [1 , 2 , 3 ])
104
+ @pytest .mark .filterwarnings ("ignore::statsmodels.tools.sm_exceptions.EstimationWarning" )
105
+ @pytest .mark .filterwarnings ("ignore::FutureWarning" )
106
+ def test_DFM_update_matches_statsmodels (data , k_factors , factor_order , error_order , rng ):
107
+ # --- Fit Statsmodels DynamicFactor with random small params ---
108
+ sm_dfm = DynamicFactor (
109
+ endog = data ,
110
+ k_factors = k_factors ,
111
+ factor_order = factor_order ,
112
+ error_order = error_order ,
113
+ )
114
+ param_names = sm_dfm .param_names
115
+ param_dict = {param : getattr (np , floatX )(rng .normal (scale = 0.1 ) ** 2 ) for param in param_names }
116
+ sm_res = sm_dfm .fit_constrained (param_dict )
117
+
118
+ # --- Setup BayesianDynamicFactor ---
119
+ mod = BayesianDynamicFactor (
120
+ k_factors = k_factors ,
121
+ factor_order = factor_order ,
122
+ error_order = error_order ,
123
+ k_endog = data .shape [1 ],
124
+ measurement_error = False ,
125
+ verbose = False ,
126
+ )
127
+
128
+ # Convert flat param dict to PyTensor variables as needed
129
+ # Reshape factor_ar and error_ar parameters according to model expected shapes
130
+ factor_ar_shape = (k_factors , factor_order * k_factors )
131
+ error_ar_shape = (data .shape [1 ], error_order ) if error_order > 0 else (0 ,)
132
+
133
+ # Prepare parameter arrays to set as deterministic
134
+ # Extract each group of parameters by name pattern (simplified)
135
+ factor_loadings = np .array ([param_dict [p ] for p in param_names if "loading" in p ]).reshape (
136
+ (data .shape [1 ], k_factors )
137
+ )
138
+
139
+ # Handle factor_ar parameters - need to account for different factor orders
140
+ factor_ar_params = []
141
+
142
+ for factor_idx in range (1 , k_factors + 1 ):
143
+ for lag in range (1 , factor_order + 1 ):
144
+ for factor_idx2 in range (1 , k_factors + 1 ):
145
+ param_pattern = f"L{ lag } .f{ factor_idx2 } .f{ factor_idx } "
146
+ if param_pattern in param_names :
147
+ factor_ar_params .append (param_pattern )
148
+
149
+ if len (factor_ar_params ) > 0 :
150
+ factor_ar_values = [param_dict [p ] for p in factor_ar_params ]
151
+ factor_ar = np .array (factor_ar_values ).reshape (factor_ar_shape )
152
+ else :
153
+ factor_ar = np .zeros (factor_ar_shape )
154
+
155
+ # factor_sigma = np.array([param_dict[p] for p in param_names if "factor.sigma" in p])
156
+
157
+ # Handle error AR parameters - need to account for different error orders and variables
158
+ if error_order > 0 :
159
+ error_ar_params = []
160
+ var_names = [col for col in data .columns ] # Get variable names from data
161
+
162
+ # Order parameters by variable first, then by lag to match expected shape (n_vars, n_lags)
163
+ for var_name in var_names :
164
+ for lag in range (1 , error_order + 1 ):
165
+ param_pattern = f"L{ lag } .e({ var_name } ).e({ var_name } )"
166
+ if param_pattern in param_names :
167
+ error_ar_params .append (param_pattern )
168
+
169
+ if len (error_ar_params ) > 0 :
170
+ error_ar_values = [param_dict [p ] for p in error_ar_params ]
171
+ error_ar = np .array (error_ar_values ).reshape (error_ar_shape )
172
+ else :
173
+ error_ar = np .zeros (error_ar_shape )
174
+
175
+ # Handle observation error variances - look for sigma2 pattern
176
+ sigma_obs_params = [p for p in param_names if "sigma2." in p ]
177
+ sigma_obs = np .array ([param_dict [p ] for p in sigma_obs_params ])
178
+
179
+ # Handle error variances (if needed separately from sigma_obs)
180
+ if error_order > 0 :
181
+ error_sigma = sigma_obs # In this case, error_sigma is the same as sigma_obs
182
+
183
+ coords = mod .coords
184
+ with pm .Model (coords = coords ) as model :
185
+ k_states = k_factors * max (1 , factor_order ) + (
186
+ error_order * data .shape [1 ] if error_order > 0 else 0
187
+ )
188
+ pm .Deterministic ("x0" , pt .zeros (k_states , dtype = floatX ))
189
+ pm .Deterministic ("P0" , pt .eye (k_states , dtype = floatX ))
190
+ # Set deterministic variables with constrained parameter values
191
+ pm .Deterministic ("factor_loadings" , pt .as_tensor_variable (factor_loadings ))
192
+ if factor_order > 0 :
193
+ pm .Deterministic ("factor_ar" , pt .as_tensor_variable (factor_ar ))
194
+ # pm.Deterministic("factor_sigma", pt.as_tensor_variable(factor_sigma))
195
+ if error_order > 0 :
196
+ pm .Deterministic ("error_ar" , pt .as_tensor_variable (error_ar ))
197
+ pm .Deterministic ("error_sigma" , pt .as_tensor_variable (error_sigma ))
198
+ pm .Deterministic ("sigma_obs" , pt .as_tensor_variable (sigma_obs ))
199
+
200
+ mod ._insert_random_variables ()
201
+
202
+ # Draw the substituted state-space matrices from PyMC model
203
+ matrices = pm .draw (mod .subbed_ssm )
204
+ matrix_dict = dict (zip (SHORT_NAME_TO_LONG .values (), matrices ))
205
+
206
+ # Matrices to check
207
+ matrices_to_check = ["transition" , "selection" , "state_cov" , "obs_cov" , "design" ]
208
+
209
+ # Compare matrices from PyMC and Statsmodels
210
+ for mat_name in matrices_to_check :
211
+ sm_mat = np .array (sm_dfm .ssm [mat_name ])
212
+ pm_mat = matrix_dict [mat_name ]
213
+
214
+ assert_allclose (
215
+ pm_mat ,
216
+ sm_mat ,
217
+ atol = 1e-10 ,
218
+ err_msg = f"Matrix mismatch: { mat_name } (k_factors={ k_factors } , factor_order={ factor_order } , error_order={ error_order } )" ,
219
+ )
0 commit comments