Skip to content

Commit 4d88343

Browse files
Update tests to match new signature
1 parent a81079b commit 4d88343

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_laplace.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,11 @@ def test_fit_laplace_ragged_coords(rng):
202202

203203

204204
@pytest.mark.parametrize(
205-
"transform_samples",
205+
"fit_in_unconstrained_space",
206206
[True, False],
207207
ids=["transformed", "untransformed"],
208208
)
209-
def test_fit_laplace(transform_samples):
209+
def test_fit_laplace(fit_in_unconstrained_space):
210210
with pm.Model() as simp_model:
211211
mu = pm.Normal("mu", mu=3, sigma=0.5)
212212
sigma = pm.Exponential("sigma", 1)
@@ -221,7 +221,7 @@ def test_fit_laplace(transform_samples):
221221
optimize_method="trust-ncg",
222222
use_grad=True,
223223
use_hessp=True,
224-
transform_samples=transform_samples,
224+
fit_in_unconstrained_space=fit_in_unconstrained_space,
225225
optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
226226
)
227227

@@ -230,7 +230,7 @@ def test_fit_laplace(transform_samples):
230230
np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1
231231
)
232232

233-
if transform_samples:
233+
if fit_in_unconstrained_space:
234234
assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"]
235235
np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1)
236236
else:

0 commit comments

Comments
 (0)