Skip to content

Commit e42e972

Browse files
committed
test: Add tests for jax backend
1 parent aebc00f commit e42e972

File tree

3 files changed

+64
-24
lines changed

3 files changed

+64
-24
lines changed

python/nutpie/compile_pymc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
297297
orig_logp_fn = logp_fn._fun
298298

299299
@jax.jit
300-
def logp_fn_jax_grad(x):
301-
return jax.value_and_grad(lambda x: orig_logp_fn(x)[0])(x)
300+
def logp_fn_jax_grad(x, **shared):
301+
return jax.value_and_grad(lambda x: orig_logp_fn(x, **shared)[0])(x)
302302

303303
logp_fn = logp_fn_jax_grad
304304

@@ -384,12 +384,17 @@ def compile_pymc_model(
384384
"and restart your kernel in case you are in an interactive session."
385385
)
386386

387+
if backend is None:
388+
backend = "numba"
389+
387390
if backend.lower() == "numba":
388391
return _compile_pymc_model_numba(model, **kwargs)
389392
elif backend.lower() == "jax":
390393
return _compile_pymc_model_jax(
391394
model, gradient_backend=gradient_backend, **kwargs
392395
)
396+
else:
397+
raise ValueError(f"Backend must be one of numba and jax. Got {backend}")
393398

394399

395400
def _compute_shapes(model):

python/nutpie/compiled_pyfunc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def with_data(self, **updates):
3737

3838
updated = self._shared_data.copy()
3939
updated.update(**updates)
40-
return dataclasses.replace(self, shared_data=updated)
40+
return dataclasses.replace(self, _shared_data=updated)
4141

4242
def _make_sampler(self, settings, init_mean, cores, progress_type):
4343
model = self._make_model(init_mean)
@@ -49,7 +49,6 @@ def _make_sampler(self, settings, init_mean, cores, progress_type):
4949
)
5050

5151
def _make_model(self, init_mean):
52-
5352
def make_logp_func():
5453
logp_fn = self._make_logp_func()
5554
return partial(logp_fn, **self._shared_data)

tests/test_pymc.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,63 +6,87 @@
66
import nutpie.compile_pymc
77

88

9-
def test_pymc_model():
9+
parameterize_backends = pytest.mark.parametrize(
10+
"backend, gradient_backend",
11+
[("numba", None), ("jax", "pytensor"), ("jax", "jax")],
12+
)
13+
14+
15+
@parameterize_backends
16+
def test_pymc_model(backend, gradient_backend):
1017
with pm.Model() as model:
1118
pm.Normal("a")
1219

13-
compiled = nutpie.compile_pymc_model(model)
20+
compiled = nutpie.compile_pymc_model(
21+
model, backend=backend, gradient_backend=gradient_backend
22+
)
1423
trace = nutpie.sample(compiled, chains=1)
1524
trace.posterior.a # noqa: B018
1625

1726

18-
def test_blocking():
27+
@parameterize_backends
28+
def test_blocking(backend, gradient_backend):
1929
with pm.Model() as model:
2030
pm.Normal("a")
2131

22-
compiled = nutpie.compile_pymc_model(model)
32+
compiled = nutpie.compile_pymc_model(
33+
model, backend=backend, gradient_backend=gradient_backend
34+
)
2335
sampler = nutpie.sample(compiled, chains=1, blocking=False)
2436
trace = sampler.wait()
2537
trace.posterior.a # noqa: B018
2638

2739

40+
@parameterize_backends
2841
@pytest.mark.timeout(2)
29-
def test_wait_timeout():
42+
def test_wait_timeout(backend, gradient_backend):
3043
with pm.Model() as model:
3144
pm.Normal("a", shape=100_000)
32-
compiled = nutpie.compile_pymc_model(model)
45+
compiled = nutpie.compile_pymc_model(
46+
model, backend=backend, gradient_backend=gradient_backend
47+
)
3348
sampler = nutpie.sample(compiled, chains=1, blocking=False)
3449
with pytest.raises(TimeoutError):
3550
sampler.wait(timeout=0.1)
3651
sampler.cancel()
3752

3853

54+
@parameterize_backends
3955
@pytest.mark.timeout(2)
40-
def test_pause():
56+
def test_pause(backend, gradient_backend):
4157
with pm.Model() as model:
4258
pm.Normal("a", shape=100_000)
43-
compiled = nutpie.compile_pymc_model(model)
59+
compiled = nutpie.compile_pymc_model(
60+
model, backend=backend, gradient_backend=gradient_backend
61+
)
4462
sampler = nutpie.sample(compiled, chains=1, blocking=False)
4563
sampler.pause()
4664
sampler.resume()
4765
sampler.cancel()
4866

4967

50-
def test_pymc_model_with_coordinate():
68+
@parameterize_backends
69+
def test_pymc_model_with_coordinate(backend, gradient_backend):
5170
with pm.Model() as model:
5271
model.add_coord("foo", length=5)
5372
pm.Normal("a", dims="foo")
5473

55-
compiled = nutpie.compile_pymc_model(model)
74+
compiled = nutpie.compile_pymc_model(
75+
model, backend=backend, gradient_backend=gradient_backend
76+
)
5677
trace = nutpie.sample(compiled, chains=1)
5778
trace.posterior.a # noqa: B018
5879

5980

60-
def test_pymc_model_store_extra():
81+
@parameterize_backends
82+
def test_pymc_model_store_extra(backend, gradient_backend):
6183
with pm.Model() as model:
6284
model.add_coord("foo", length=5)
6385
pm.Normal("a", dims="foo")
6486

65-
compiled = nutpie.compile_pymc_model(model)
87+
compiled = nutpie.compile_pymc_model(
88+
model, backend=backend, gradient_backend=gradient_backend
89+
)
6690
trace = nutpie.sample(
6791
compiled,
6892
chains=1,
@@ -78,33 +102,42 @@ def test_pymc_model_store_extra():
78102
_ = trace.sample_stats.mass_matrix_inv
79103

80104

81-
def test_trafo():
105+
@parameterize_backends
106+
def test_trafo(backend, gradient_backend):
82107
with pm.Model() as model:
83108
pm.Uniform("a")
84109

85-
compiled = nutpie.compile_pymc_model(model)
110+
compiled = nutpie.compile_pymc_model(
111+
model, backend=backend, gradient_backend=gradient_backend
112+
)
86113
trace = nutpie.sample(compiled, chains=1)
87114
trace.posterior.a # noqa: B018
88115

89116

90-
def test_det():
117+
@parameterize_backends
118+
def test_det(backend, gradient_backend):
91119
with pm.Model() as model:
92120
a = pm.Uniform("a", shape=2)
93121
pm.Deterministic("b", 2 * a)
94122

95-
compiled = nutpie.compile_pymc_model(model)
123+
compiled = nutpie.compile_pymc_model(
124+
model, backend=backend, gradient_backend=gradient_backend
125+
)
96126
trace = nutpie.sample(compiled, chains=1)
97127
assert trace.posterior.a.shape[-1] == 2
98128
assert trace.posterior.b.shape[-1] == 2
99129

100130

101-
def test_pymc_model_shared():
131+
@parameterize_backends
132+
def test_pymc_model_shared(backend, gradient_backend):
102133
with pm.Model() as model:
103134
mu = pm.MutableData("mu", 0.1)
104135
sigma = pm.MutableData("sigma", np.ones(3))
105136
pm.Normal("a", mu=mu, sigma=sigma, shape=3)
106137

107-
compiled = nutpie.compile_pymc_model(model)
138+
compiled = nutpie.compile_pymc_model(
139+
model, backend=backend, gradient_backend=gradient_backend
140+
)
108141
trace = nutpie.sample(compiled, chains=1, seed=1)
109142
np.testing.assert_allclose(trace.posterior.a.mean().values, 0.1, atol=0.05)
110143

@@ -117,13 +150,16 @@ def test_pymc_model_shared():
117150
nutpie.sample(compiled3, chains=1)
118151

119152

120-
def test_missing():
153+
@parameterize_backends
154+
def test_missing(backend, gradient_backend):
121155
with pm.Model(coords={"obs": range(4)}) as model:
122156
mu = pm.Normal("mu")
123157
y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")
124158
pm.Deterministic("y2", 2 * y, dims="obs")
125159

126-
compiled = nutpie.compile_pymc_model(model)
160+
compiled = nutpie.compile_pymc_model(
161+
model, backend=backend, gradient_backend=gradient_backend
162+
)
127163
tr = nutpie.sample(compiled, chains=1, seed=1)
128164
print(tr.posterior)
129165
assert hasattr(tr.posterior, "y_unobserved")

0 commit comments

Comments
 (0)