1+ from contextlib import contextmanager
2+
13import arviz as az
24import numpy as np
35import pymc as pm
1618)
1719
1820
21+ @contextmanager
22+ def no_op ():
23+ yield
24+
25+
1926@pytest .fixture
2027def rng ():
2128 return np .random .default_rng ()
@@ -62,14 +69,18 @@ def hierarchical_model(rng):
6269 return model , mu_val , H_inv , test_point
6370
6471
65- def test_laplace_draws_to_inferencedata ( simple_model , rng ):
66- # Simulate posterior draws: 2 variables, each (chains, draws)
72+ @ pytest . mark . parametrize ( "use_context" , [ False , True ], ids = [ "model_arg" , "model_context" ])
73+ def test_laplace_draws_to_inferencedata ( use_context , simple_model , rng ):
6774 chains , draws = 2 , 5
6875 mu_draws = rng .normal (size = (chains , draws ))
6976 sigma_draws = np .abs (rng .normal (size = (chains , draws )))
7077 model , * _ = simple_model
7178
72- idata = laplace_draws_to_inferencedata ([mu_draws , sigma_draws ], model = model )
79+ context = model if use_context else no_op ()
80+ model_arg = model if not use_context else None
81+
82+ with context :
83+ idata = laplace_draws_to_inferencedata ([mu_draws , sigma_draws ], model = model_arg )
7384
7485 assert isinstance (idata , az .InferenceData )
7586 assert "mu" in idata .posterior
@@ -93,14 +104,21 @@ def check_idata(self, idata, var_names, n_vars):
93104 assert fit .coords ["rows" ].values .tolist () == var_names
94105 assert fit .coords ["columns" ].values .tolist () == var_names
95106
96- def test_add_fit_to_inferencedata (self , simple_model , rng ):
107+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
108+ def test_add_fit_to_inferencedata (self , use_context , simple_model , rng ):
97109 model , mu_val , H_inv , test_point = simple_model
98110 idata = az .from_dict (posterior = {"mu" : rng .normal (size = ()), "sigma" : rng .normal (size = ())})
99- idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model )
111+
112+ context = model if use_context else no_op ()
113+ model_arg = model if not use_context else None
114+
115+ with context :
116+ idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model_arg )
100117
101118 self .check_idata (idata2 , ["mu" , "sigma" ], 2 )
102119
103- def test_add_fit_with_coords_to_inferencedata (self , hierarchical_model , rng ):
120+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
121+ def test_add_fit_with_coords_to_inferencedata (self , use_context , hierarchical_model , rng ):
104122 model , mu_val , H_inv , test_point = hierarchical_model
105123 idata = az .from_dict (
106124 posterior = {
@@ -111,26 +129,38 @@ def test_add_fit_with_coords_to_inferencedata(self, hierarchical_model, rng):
111129 }
112130 )
113131
114- idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model )
132+ context = model if use_context else no_op ()
133+ model_arg = model if not use_context else None
134+
135+ with context :
136+ idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model_arg )
115137
116138 self .check_idata (
117139 idata2 , ["mu_loc" , "mu_scale" , "mu[1]" , "mu[2]" , "mu[3]" , "mu[4]" , "mu[5]" , "sigma" ], 8
118140 )
119141
120142
121- def test_add_data_to_inferencedata (simple_model , rng ):
143+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
144+ def test_add_data_to_inferencedata (use_context , simple_model , rng ):
122145 model , * _ = simple_model
123146
124147 idata = az .from_dict (
125148 posterior = {"mu" : rng .standard_normal ((1 , 1 )), "sigma" : rng .standard_normal ((1 , 1 ))}
126149 )
127- idata2 = add_data_to_inference_data (idata , model = model )
150+
151+ context = model if use_context else no_op ()
152+ model_arg = model if not use_context else None
153+
154+ with context :
155+ idata2 = add_data_to_inference_data (idata , model = model_arg )
156+
128157 assert "observed_data" in idata2 .groups ()
129158 assert "constant_data" in idata2 .groups ()
130159 assert "obs" in idata2 .observed_data
131160
132161
133- def test_optimizer_result_to_dataset_basic (simple_model , rng ):
162+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
163+ def test_optimizer_result_to_dataset_basic (use_context , simple_model , rng ):
134164 model , mu_val , H_inv , test_point = simple_model
135165 result = OptimizeResult (
136166 x = np .array ([1.0 , 2.0 ]),
@@ -144,7 +174,11 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
144174 status = 0 ,
145175 )
146176
147- ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model , mu = test_point )
177+ context = model if use_context else no_op ()
178+ model_arg = model if not use_context else None
179+ with context :
180+ ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model_arg , mu = test_point )
181+
148182 assert isinstance (ds , xr .Dataset )
149183 assert all (
150184 key in ds
@@ -169,48 +203,68 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
169203 assert ds ["jac" ].coords ["variables" ].values .tolist () == ["mu" , "sigma" ]
170204
171205
172- def test_optimizer_result_to_dataset_hess_inv_matrix (hierarchical_model , rng ):
173- model , mu_val , H_inv , test_point = hierarchical_model
174- result = OptimizeResult (
175- x = np .zeros ((8 ,)),
176- hess_inv = np .eye (8 ),
206+ @pytest .mark .parametrize (
207+ "optimizer_method, use_context, model_name" ,
208+ [("BFGS" , True , "hierarchical_model" ), ("L-BFGS-B" , False , "simple_model" )],
209+ )
210+ def test_optimizer_result_to_dataset_hess_inv_types (
211+ optimizer_method , use_context , model_name , rng , request
212+ ):
213+ def get_hess_inv_and_expected_names (method ):
214+ model , mu_val , H_inv , test_point = request .getfixturevalue (model_name )
215+ n = mu_val .shape [0 ]
216+
217+ if method == "BFGS" :
218+ hess_inv = np .eye (n )
219+ expected_names = [
220+ "mu_loc" ,
221+ "mu_scale" ,
222+ "mu[1]" ,
223+ "mu[2]" ,
224+ "mu[3]" ,
225+ "mu[4]" ,
226+ "mu[5]" ,
227+ "sigma" ,
228+ ]
229+ result = OptimizeResult (
230+ x = np .zeros ((n ,)),
231+ hess_inv = hess_inv ,
232+ )
233+ elif method == "L-BFGS-B" :
234+
235+ def linop_func (x ):
236+ return np .array ([2 * xi for xi in x ])
237+
238+ linop = LinearOperator ((n , n ), matvec = linop_func )
239+ hess_inv = 2 * np .eye (n )
240+ expected_names = ["mu" , "sigma" ]
241+ result = OptimizeResult (
242+ x = np .ones (n ),
243+ hess_inv = linop ,
244+ )
245+ else :
246+ raise ValueError ("Unknown optimizer_method" )
247+
248+ return model , test_point , hess_inv , expected_names , result
249+
250+ model , test_point , hess_inv , expected_names , result = get_hess_inv_and_expected_names (
251+ optimizer_method
177252 )
178- ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model , mu = test_point )
179253
180- assert "hess_inv" in ds
181- assert ds ["hess_inv" ].shape == (8 , 8 )
182- assert list (ds ["hess_inv" ].coords .keys ()) == ["variables" , "variables_aux" ]
183-
184- expected_names = ["mu_loc" , "mu_scale" , "mu[1]" , "mu[2]" , "mu[3]" , "mu[4]" , "mu[5]" , "sigma" ]
185- assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
186- assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
187-
188-
189- def test_optimizer_result_to_dataset_hess_inv_linear_operator (simple_model , rng ):
190- model , mu_val , H_inv , test_point = simple_model
191- n = mu_val .shape [0 ]
192-
193- def matvec (x ):
194- return np .array ([2 * xi for xi in x ])
195-
196- linop = LinearOperator ((n , n ), matvec = matvec )
197- result = OptimizeResult (
198- x = np .ones (n ),
199- hess_inv = linop ,
200- )
254+ context = model if use_context else no_op ()
255+ model_arg = model if not use_context else None
201256
202- with model :
203- ds = optimizer_result_to_dataset (result , method = "BFGS" , mu = test_point )
257+ with context :
258+ ds = optimizer_result_to_dataset (
259+ result , method = optimizer_method , mu = test_point , model = model_arg
260+ )
204261
205262 assert "hess_inv" in ds
206- assert ds ["hess_inv" ].shape == (n , n )
263+ assert ds ["hess_inv" ].shape == (len ( expected_names ), len ( expected_names ) )
207264 assert list (ds ["hess_inv" ].coords .keys ()) == ["variables" , "variables_aux" ]
208-
209- expected_names = ["mu" , "sigma" ]
210265 assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
211266 assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
212-
213- np .testing .assert_allclose (ds ["hess_inv" ].values , 2 * np .eye (n ))
267+ np .testing .assert_allclose (ds ["hess_inv" ].values , hess_inv )
214268
215269
216270def test_optimizer_result_to_dataset_extra_fields (simple_model , rng ):
@@ -228,3 +282,25 @@ def test_optimizer_result_to_dataset_extra_fields(simple_model, rng):
228282 assert ds ["custom_stat" ].shape == (2 ,)
229283 assert list (ds ["custom_stat" ].coords .keys ()) == ["custom_stat_dim_0" ]
230284 assert ds ["custom_stat" ].coords ["custom_stat_dim_0" ].values .tolist () == [0 , 1 ]
285+
286+
287+ def test_optimizer_result_to_dataset_hess_inv_basinhopping (simple_model , rng ):
288+ model , mu_val , H_inv , test_point = simple_model
289+ n = mu_val .shape [0 ]
290+ hess_inv_inner = np .eye (n ) * 3.0
291+
292+ # Basinhopping returns an OptimizeResult with a nested OptimizeResult
293+ result = OptimizeResult (
294+ x = np .ones (n ),
295+ lowest_optimization_result = OptimizeResult (x = np .ones (n ), hess_inv = hess_inv_inner ),
296+ )
297+
298+ with model :
299+ ds = optimizer_result_to_dataset (result , method = "basinhopping" , mu = test_point )
300+
301+ assert "hess_inv" in ds
302+ assert ds ["hess_inv" ].shape == (n , n )
303+ np .testing .assert_allclose (ds ["hess_inv" ].values , hess_inv_inner )
304+ expected_names = ["mu" , "sigma" ]
305+ assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
306+ assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
0 commit comments