Skip to content

Commit 3d619cf

Browse files
committed
test: add low rank tests
1 parent c452882 commit 3d619cf

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

tests/test_pymc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,32 @@ def test_pymc_model(backend, gradient_backend):
3131
trace.posterior.a # noqa: B018
3232

3333

34+
@pytest.mark.pymc
35+
@parameterize_backends
36+
def test_low_rank(backend, gradient_backend):
37+
with pm.Model() as model:
38+
pm.Normal("a")
39+
40+
compiled = nutpie.compile_pymc_model(
41+
model, backend=backend, gradient_backend=gradient_backend
42+
)
43+
trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True)
44+
trace.posterior.a # noqa: B018
45+
46+
47+
@pytest.mark.pymc
48+
@parameterize_backends
49+
def test_low_rank_half_normal(backend, gradient_backend):
50+
with pm.Model() as model:
51+
pm.HalfNormal("a", shape=13)
52+
53+
compiled = nutpie.compile_pymc_model(
54+
model, backend=backend, gradient_backend=gradient_backend
55+
)
56+
trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True)
57+
trace.posterior.a # noqa: B018
58+
59+
3460
@pytest.mark.pymc
3561
@parameterize_backends
3662
def test_zero_size(backend, gradient_backend):

tests/test_stan.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@ def test_stan_model():
2727
trace.posterior.a # noqa: B018
2828

2929

30+
@pytest.mark.stan
31+
def test_stan_model_low_rank():
32+
model = """
33+
data {}
34+
parameters {
35+
real a;
36+
}
37+
model {
38+
a ~ normal(0, 1);
39+
}
40+
"""
41+
42+
compiled_model = nutpie.compile_stan_model(code=model)
43+
trace = nutpie.sample(compiled_model, low_rank_modified_mass_matrix=True)
44+
trace.posterior.a # noqa: B018
45+
46+
3047
@pytest.mark.stan
3148
def test_empty():
3249
model = """
@@ -40,7 +57,7 @@ def test_empty():
4057
"""
4158

4259
compiled_model = nutpie.compile_stan_model(code=model)
43-
trace = nutpie.sample(compiled_model) # noqa: F841
60+
nutpie.sample(compiled_model)
4461
# TODO: Variable `a` is missing because of this bridgestan issue:
4562
# https://github.com/roualdes/bridgestan/issues/278
4663
# assert trace.posterior.a.shape == (0, 1000)

0 commit comments

Comments
 (0)