@@ -79,7 +79,7 @@ def test_dfm_parameter_and_matrix_match(data, k_factors, factor_order, error_ord
7979 pm .Deterministic (
8080 "error_ar" , pt .constant (np .full ((k_endog , error_order ), 0.5 ), dtype = floatX )
8181 )
82- pm .Deterministic ("factor_sigma" , pt .constant (np .full ((k_factors ,), 0.5 ), dtype = floatX ))
82+ # pm.Deterministic("factor_sigma", pt.constant(np.full((k_factors,), 0.5), dtype=floatX))
8383 pm .Deterministic ("error_sigma" , pt .constant (np .full ((k_endog ,), 0.5 ), dtype = floatX ))
8484 pm .Deterministic ("sigma_obs" , pt .constant (np .full ((k_endog ,), 0.5 ), dtype = floatX ))
8585
@@ -96,3 +96,124 @@ def test_dfm_parameter_and_matrix_match(data, k_factors, factor_order, error_ord
9696 atol = 1e-12 ,
9797 err_msg = f"Matrix mismatch: { mat_name } (k_factors={ k_factors } , factor_order={ factor_order } , error_order={ error_order } )" ,
9898 )
99+
100+
101+ @pytest .mark .parametrize ("k_factors" , [1 , 2 ])
102+ @pytest .mark .parametrize ("factor_order" , [0 , 1 , 2 ])
103+ @pytest .mark .parametrize ("error_order" , [1 , 2 , 3 ])
104+ @pytest .mark .filterwarnings ("ignore::statsmodels.tools.sm_exceptions.EstimationWarning" )
105+ @pytest .mark .filterwarnings ("ignore::FutureWarning" )
106+ def test_DFM_update_matches_statsmodels (data , k_factors , factor_order , error_order , rng ):
107+ # --- Fit Statsmodels DynamicFactor with random small params ---
108+ sm_dfm = DynamicFactor (
109+ endog = data ,
110+ k_factors = k_factors ,
111+ factor_order = factor_order ,
112+ error_order = error_order ,
113+ )
114+ param_names = sm_dfm .param_names
115+ param_dict = {param : getattr (np , floatX )(rng .normal (scale = 0.1 ) ** 2 ) for param in param_names }
116+ sm_res = sm_dfm .fit_constrained (param_dict )
117+
118+ # --- Setup BayesianDynamicFactor ---
119+ mod = BayesianDynamicFactor (
120+ k_factors = k_factors ,
121+ factor_order = factor_order ,
122+ error_order = error_order ,
123+ k_endog = data .shape [1 ],
124+ measurement_error = False ,
125+ verbose = False ,
126+ )
127+
128+ # Convert flat param dict to PyTensor variables as needed
129+ # Reshape factor_ar and error_ar parameters according to model expected shapes
130+ factor_ar_shape = (k_factors , factor_order * k_factors )
131+ error_ar_shape = (data .shape [1 ], error_order ) if error_order > 0 else (0 ,)
132+
133+ # Prepare parameter arrays to set as deterministic
134+ # Extract each group of parameters by name pattern (simplified)
135+ factor_loadings = np .array ([param_dict [p ] for p in param_names if "loading" in p ]).reshape (
136+ (data .shape [1 ], k_factors )
137+ )
138+
139+ # Handle factor_ar parameters - need to account for different factor orders
140+ factor_ar_params = []
141+
142+ for factor_idx in range (1 , k_factors + 1 ):
143+ for lag in range (1 , factor_order + 1 ):
144+ for factor_idx2 in range (1 , k_factors + 1 ):
145+ param_pattern = f"L{ lag } .f{ factor_idx2 } .f{ factor_idx } "
146+ if param_pattern in param_names :
147+ factor_ar_params .append (param_pattern )
148+
149+ if len (factor_ar_params ) > 0 :
150+ factor_ar_values = [param_dict [p ] for p in factor_ar_params ]
151+ factor_ar = np .array (factor_ar_values ).reshape (factor_ar_shape )
152+ else :
153+ factor_ar = np .zeros (factor_ar_shape )
154+
155+ # factor_sigma = np.array([param_dict[p] for p in param_names if "factor.sigma" in p])
156+
157+ # Handle error AR parameters - need to account for different error orders and variables
158+ if error_order > 0 :
159+ error_ar_params = []
160+ var_names = [col for col in data .columns ] # Get variable names from data
161+
162+ # Order parameters by variable first, then by lag to match expected shape (n_vars, n_lags)
163+ for var_name in var_names :
164+ for lag in range (1 , error_order + 1 ):
165+ param_pattern = f"L{ lag } .e({ var_name } ).e({ var_name } )"
166+ if param_pattern in param_names :
167+ error_ar_params .append (param_pattern )
168+
169+ if len (error_ar_params ) > 0 :
170+ error_ar_values = [param_dict [p ] for p in error_ar_params ]
171+ error_ar = np .array (error_ar_values ).reshape (error_ar_shape )
172+ else :
173+ error_ar = np .zeros (error_ar_shape )
174+
175+ # Handle observation error variances - look for sigma2 pattern
176+ sigma_obs_params = [p for p in param_names if "sigma2." in p ]
177+ sigma_obs = np .array ([param_dict [p ] for p in sigma_obs_params ])
178+
179+ # Handle error variances (if needed separately from sigma_obs)
180+ if error_order > 0 :
181+ error_sigma = sigma_obs # In this case, error_sigma is the same as sigma_obs
182+
183+ coords = mod .coords
184+ with pm .Model (coords = coords ) as model :
185+ k_states = k_factors * max (1 , factor_order ) + (
186+ error_order * data .shape [1 ] if error_order > 0 else 0
187+ )
188+ pm .Deterministic ("x0" , pt .zeros (k_states , dtype = floatX ))
189+ pm .Deterministic ("P0" , pt .eye (k_states , dtype = floatX ))
190+ # Set deterministic variables with constrained parameter values
191+ pm .Deterministic ("factor_loadings" , pt .as_tensor_variable (factor_loadings ))
192+ if factor_order > 0 :
193+ pm .Deterministic ("factor_ar" , pt .as_tensor_variable (factor_ar ))
194+ # pm.Deterministic("factor_sigma", pt.as_tensor_variable(factor_sigma))
195+ if error_order > 0 :
196+ pm .Deterministic ("error_ar" , pt .as_tensor_variable (error_ar ))
197+ pm .Deterministic ("error_sigma" , pt .as_tensor_variable (error_sigma ))
198+ pm .Deterministic ("sigma_obs" , pt .as_tensor_variable (sigma_obs ))
199+
200+ mod ._insert_random_variables ()
201+
202+ # Draw the substituted state-space matrices from PyMC model
203+ matrices = pm .draw (mod .subbed_ssm )
204+ matrix_dict = dict (zip (SHORT_NAME_TO_LONG .values (), matrices ))
205+
206+ # Matrices to check
207+ matrices_to_check = ["transition" , "selection" , "state_cov" , "obs_cov" , "design" ]
208+
209+ # Compare matrices from PyMC and Statsmodels
210+ for mat_name in matrices_to_check :
211+ sm_mat = np .array (sm_dfm .ssm [mat_name ])
212+ pm_mat = matrix_dict [mat_name ]
213+
214+ assert_allclose (
215+ pm_mat ,
216+ sm_mat ,
217+ atol = 1e-10 ,
218+ err_msg = f"Matrix mismatch: { mat_name } (k_factors={ k_factors } , factor_order={ factor_order } , error_order={ error_order } )" ,
219+ )
0 commit comments