@@ -65,20 +65,20 @@ def compute_z(x):
65
65
],
66
66
)
67
67
def test_JAX_map (method , use_grad , use_hess , rng ):
68
- with pm .Model () as m :
69
- mu = pm .Normal ("mu" )
70
- sigma = pm .Exponential ("sigma" , 1 )
71
- pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = rng .normal (loc = 3 , scale = 1.5 , size = 100 ))
72
-
73
68
extra_kwargs = {}
74
69
if method == "dogleg" :
75
70
# HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
76
71
# where this is true
77
72
extra_kwargs = {"initvals" : {"mu" : 2 , "sigma_log__" : 1 }}
78
73
79
- optimized_point = find_MAP (
80
- m , method , ** extra_kwargs , use_grad = use_grad , use_hess = use_hess , progressbar = False
81
- )
74
+ with pm .Model () as m :
75
+ mu = pm .Normal ("mu" )
76
+ sigma = pm .Exponential ("sigma" , 1 )
77
+ pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = rng .normal (loc = 3 , scale = 1.5 , size = 100 ))
78
+
79
+ optimized_point = find_MAP (
80
+ method = method , ** extra_kwargs , use_grad = use_grad , use_hess = use_hess , progressbar = False
81
+ )
82
82
mu_hat , log_sigma_hat = optimized_point ["mu" ], optimized_point ["sigma_log__" ]
83
83
84
84
assert np .isclose (mu_hat , 3 , atol = 0.5 )
@@ -102,12 +102,12 @@ def test_fit_laplace_coords(rng, transform_samples):
102
102
observed = rng .normal (loc = 3 , scale = 1.5 , size = (100 , 3 )),
103
103
dims = ["obs_idx" , "city" ],
104
104
)
105
- optimized_point = find_MAP (
106
- model ,
107
- "Newton-CG" ,
108
- use_grad = True ,
109
- progressbar = False ,
110
- )
105
+
106
+ optimized_point = find_MAP (
107
+ method = "Newton-CG" ,
108
+ use_grad = True ,
109
+ progressbar = False ,
110
+ )
111
111
112
112
for value in optimized_point .values ():
113
113
assert value .shape == (3 ,)
@@ -145,9 +145,9 @@ def test_fit_laplace_ragged_coords(rng):
145
145
dims = ["obs_idx" , "city" ],
146
146
)
147
147
148
- optimized_point , _ = find_MAP (
149
- ragged_dim_model , "Newton-CG" , use_grad = True , progressbar = False , return_raw = True
150
- )
148
+ optimized_point , _ = find_MAP (
149
+ method = "Newton-CG" , use_grad = True , progressbar = False , return_raw = True
150
+ )
151
151
152
152
idata = fit_laplace (optimized_point , ragged_dim_model , progressbar = False )
153
153
@@ -176,12 +176,12 @@ def test_fit_laplace(transform_samples):
176
176
observed = np .random .default_rng ().normal (loc = 3 , scale = 1.5 , size = (10000 ,)),
177
177
)
178
178
179
- optimized_point = find_MAP (
180
- simp_model ,
181
- "Newton-CG" ,
182
- use_grad = True ,
183
- progressbar = False ,
184
- )
179
+ optimized_point = find_MAP (
180
+ method = "Newton-CG" ,
181
+ use_grad = True ,
182
+ progressbar = False ,
183
+ )
184
+
185
185
idata = fit_laplace (
186
186
optimized_point , simp_model , transform_samples = transform_samples , progressbar = False
187
187
)
0 commit comments