Skip to content

Commit 4e2280b

Browse files
committed
test: mark tests as pymc or stan and select in ci
1 parent d8a5ce2 commit 4e2280b

File tree

6 files changed

+105
-14
lines changed

6 files changed

+105
-14
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,14 @@ jobs:
6767
set -e
6868
python3 -m venv .venv
6969
source .venv/bin/activate
70+
uv pip install 'nutpie[stan]' --find-links dist --force-reinstall
71+
uv pip install pytest pytest-timeout
72+
pytest -m "stan and not flow"
73+
uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall
74+
uv pip install jax
75+
pytest -m "pymc and not flow"
7076
uv pip install 'nutpie[all]' --find-links dist --force-reinstall
71-
uv pip install pytest
72-
pytest
77+
pytest -m flow
7378
- name: pytest
7479
if: ${{ !startsWith(matrix.platform.target, 'x86') && matrix.platform.target != 'ppc64' }}
7580
uses: uraimo/run-on-arch-action@v3
@@ -85,9 +90,9 @@ jobs:
8590
run: |
8691
set -e
8792
source $HOME/.local/bin/env
88-
uv pip install --system -U pip pytest
93+
uv pip install --system -U pip pytest pytest-timeout
8994
uv pip install --system 'nutpie[all]' --find-links dist --force-reinstall
90-
pytest -m "not slow"
95+
pytest -m "not flow" # Skip flow tests, they are slow on emulated platforms
9196
9297
# pyarrow doesn't currently seem to work on musllinux
9398
#musllinux:
@@ -208,7 +213,7 @@ jobs:
208213
python3 -m venv .venv
209214
source .venv/Scripts/activate
210215
uv pip install "nutpie[all]" --find-links dist --force-reinstall
211-
uv pip install pytest
216+
uv pip install pytest pytest-timeout
212217
pytest
213218
214219
macos:
@@ -252,8 +257,8 @@ jobs:
252257
python3 -m venv .venv
253258
source .venv/bin/activate
254259
uv pip install 'nutpie[all]' --find-links dist --force-reinstall
255-
uv pip install pytest
256-
pytest
260+
uv pip install pytest pytest-timeout
261+
pytest -m "not (flow and stan)" # The stan tests seem to run out of memory on macOS?
257262
258263
sdist:
259264
runs-on: ubuntu-latest

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dev = [
3535
"jax >= 0.4.27",
3636
"flowjax >= 17.0.2",
3737
"pytest",
38+
"pytest-timeout",
3839
]
3940
all = [
4041
"bridgestan >= 2.6.1",
@@ -65,3 +66,10 @@ venv = "default"
6566
module-name = "nutpie._lib"
6667
python-source = "python"
6768
features = ["pyo3/extension-module"]
69+
70+
[tool.pytest.ini_options]
71+
markers = [
72+
"flow: tests for normalizing flows",
73+
"stan: tests for Stan models",
74+
"pymc: tests for PyMC models",
75+
]

python/nutpie/compile_pymc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import numpy as np
1313
import pandas as pd
1414
from numpy.typing import NDArray
15-
from pymc.initial_point import make_initial_point_fn
1615

1716
from nutpie import _lib
1817
from nutpie.compiled_pyfunc import SeedType, from_pyfunc
@@ -501,6 +500,7 @@ def compile_pymc_model(
501500
)
502501

503502
from pymc.model.transform.optimization import freeze_dims_and_data
503+
from pymc.initial_point import make_initial_point_fn
504504

505505
if freeze_model is None:
506506
freeze_model = backend == "jax"

tests/test_pymc.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
from importlib.util import find_spec
2+
import time
3+
import pytest
4+
5+
if find_spec("pymc") is None:
6+
pytest.skip("Skip pymc tests", allow_module_level=True)
7+
18
import numpy as np
29
import pymc as pm
310
import pytest
@@ -12,6 +19,7 @@
1219
)
1320

1421

22+
@pytest.mark.pymc
1523
@parameterize_backends
1624
def test_pymc_model(backend, gradient_backend):
1725
with pm.Model() as model:
@@ -24,6 +32,7 @@ def test_pymc_model(backend, gradient_backend):
2432
trace.posterior.a # noqa: B018
2533

2634

35+
@pytest.mark.pymc
2736
@parameterize_backends
2837
def test_pymc_model_float32(backend, gradient_backend):
2938
import pytensor
@@ -39,6 +48,7 @@ def test_pymc_model_float32(backend, gradient_backend):
3948
trace.posterior.a # noqa: B018
4049

4150

51+
@pytest.mark.pymc
4252
@parameterize_backends
4353
def test_pymc_model_no_prior(backend, gradient_backend):
4454
with pm.Model() as model:
@@ -52,6 +62,7 @@ def test_pymc_model_no_prior(backend, gradient_backend):
5262
trace.posterior.a # noqa: B018
5363

5464

65+
@pytest.mark.pymc
5566
@parameterize_backends
5667
def test_blocking(backend, gradient_backend):
5768
with pm.Model() as model:
@@ -65,34 +76,41 @@ def test_blocking(backend, gradient_backend):
6576
trace.posterior.a # noqa: B018
6677

6778

79+
@pytest.mark.pymc
6880
@parameterize_backends
69-
@pytest.mark.timeout(2)
81+
@pytest.mark.timeout(20)
7082
def test_wait_timeout(backend, gradient_backend):
7183
with pm.Model() as model:
7284
pm.Normal("a", shape=100_000)
7385
compiled = nutpie.compile_pymc_model(
7486
model, backend=backend, gradient_backend=gradient_backend
7587
)
88+
start = time.time()
7689
sampler = nutpie.sample(compiled, chains=1, blocking=False)
7790
with pytest.raises(TimeoutError):
7891
sampler.wait(timeout=0.1)
7992
sampler.cancel()
93+
assert start - time.time() < 5
8094

8195

96+
@pytest.mark.pymc
8297
@parameterize_backends
83-
@pytest.mark.timeout(2)
98+
@pytest.mark.timeout(20)
8499
def test_pause(backend, gradient_backend):
85100
with pm.Model() as model:
86101
pm.Normal("a", shape=100_000)
87102
compiled = nutpie.compile_pymc_model(
88103
model, backend=backend, gradient_backend=gradient_backend
89104
)
105+
start = time.time()
90106
sampler = nutpie.sample(compiled, chains=1, blocking=False)
91107
sampler.pause()
92108
sampler.resume()
93109
sampler.cancel()
110+
assert start - time.time() < 5
94111

95112

113+
@pytest.mark.pymc
96114
@parameterize_backends
97115
def test_pymc_model_with_coordinate(backend, gradient_backend):
98116
with pm.Model() as model:
@@ -106,6 +124,7 @@ def test_pymc_model_with_coordinate(backend, gradient_backend):
106124
trace.posterior.a # noqa: B018
107125

108126

127+
@pytest.mark.pymc
109128
@parameterize_backends
110129
def test_pymc_model_store_extra(backend, gradient_backend):
111130
with pm.Model() as model:
@@ -130,6 +149,7 @@ def test_pymc_model_store_extra(backend, gradient_backend):
130149
_ = trace.sample_stats.mass_matrix_inv
131150

132151

152+
@pytest.mark.pymc
133153
@parameterize_backends
134154
def test_trafo(backend, gradient_backend):
135155
with pm.Model() as model:
@@ -142,6 +162,7 @@ def test_trafo(backend, gradient_backend):
142162
trace.posterior.a # noqa: B018
143163

144164

165+
@pytest.mark.pymc
145166
@parameterize_backends
146167
def test_det(backend, gradient_backend):
147168
with pm.Model() as model:
@@ -156,6 +177,7 @@ def test_det(backend, gradient_backend):
156177
assert trace.posterior.b.shape[-1] == 2
157178

158179

180+
@pytest.mark.pymc
159181
@parameterize_backends
160182
def test_non_identifier_names(backend, gradient_backend):
161183
with pm.Model() as model:
@@ -172,6 +194,7 @@ def test_non_identifier_names(backend, gradient_backend):
172194
assert trace.posterior["foo::b"].shape[-1] == 2
173195

174196

197+
@pytest.mark.pymc
175198
@parameterize_backends
176199
def test_pymc_model_shared(backend, gradient_backend):
177200
with pm.Model() as model:
@@ -197,6 +220,7 @@ def test_pymc_model_shared(backend, gradient_backend):
197220
nutpie.sample(compiled3, chains=1)
198221

199222

223+
@pytest.mark.pymc
200224
@parameterize_backends
201225
def test_pymc_var_names(backend, gradient_backend):
202226
with pm.Model() as model:
@@ -244,7 +268,8 @@ def test_pymc_var_names(backend, gradient_backend):
244268
assert not hasattr(trace.posterior, "c")
245269

246270

247-
@pytest.mark.slow
271+
@pytest.mark.pymc
272+
@pytest.mark.flow
248273
def test_normalizing_flow():
249274
with pm.Model() as model:
250275
pm.HalfNormal("x", shape=2)
@@ -272,6 +297,7 @@ def test_normalizing_flow():
272297
assert kstest.pvalue > 0.01
273298

274299

300+
@pytest.mark.pymc
275301
@pytest.mark.parametrize(
276302
("backend", "gradient_backend"),
277303
[

tests/test_stan.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from importlib.util import find_spec
2+
import pytest
3+
4+
if find_spec("bridgestan") is None:
5+
pytest.skip("Skip stan tests", allow_module_level=True)
6+
17
import numpy as np
28
import pytest
39

410
import nutpie
511

612

13+
@pytest.mark.stan
714
def test_stan_model():
815
model = """
916
data {}
@@ -20,6 +27,7 @@ def test_stan_model():
2027
trace.posterior.a # noqa: B018
2128

2229

30+
@pytest.mark.stan
2331
def test_stan_model_data():
2432
model = """
2533
data {
@@ -40,11 +48,55 @@ def test_stan_model_data():
4048
trace.posterior.a # noqa: B018
4149

4250

43-
@pytest.mark.slow
44-
def test_stan_flow():
51+
@pytest.mark.stan
52+
def test_stan_memory_order():
4553
model = """
54+
data {
55+
real x;
56+
}
4657
parameters {
4758
real a;
59+
}
60+
model {
61+
a ~ normal(0, 1);
62+
}
63+
generated quantities {
64+
array[2, 3] matrix[5, 7] b;
65+
real count = 0;
66+
for (i in 1:2)
67+
for (j in 1:3) {
68+
for (k in 1:5) {
69+
for (n in 1:7) {
70+
b[i, j][k, n] = count;
71+
count = count + 1;
72+
}
73+
}
74+
}
75+
}
76+
"""
77+
78+
compiled_model = nutpie.compile_stan_model(code=model)
79+
with pytest.raises(RuntimeError):
80+
trace = nutpie.sample(compiled_model)
81+
trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0)))
82+
trace.posterior.a # noqa: B018
83+
assert trace.posterior.b.shape == (6, 1000, 2, 3, 5, 7)
84+
b = trace.posterior.b.isel(chain=0, draw=0)
85+
count = 0
86+
for i in range(2):
87+
for j in range(3):
88+
for k in range(5):
89+
for n in range(7):
90+
assert float(b[i, j, k, n]) == count
91+
count += 1
92+
93+
94+
@pytest.mark.flow
95+
@pytest.mark.stan
96+
def test_stan_flow():
97+
model = """
98+
parameters {
99+
array[5] real a;
48100
real<lower=0> b;
49101
}
50102
model {

0 commit comments

Comments
 (0)