Skip to content

Commit 3c7051b

Browse files
rloufbrandonwillard
authored andcommitted
Add a simple test for numba compilation
1 parent c1f8451 commit 3c7051b

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/test_numba.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Test file created for the sole purpose of tracking the status of Numba compilation"""
2+
import aesara
3+
import aesara.tensor as at
4+
from aeppl import joint_logprob
5+
6+
import aehmc.nuts as nuts
7+
8+
9+
def test_sample_with_numba():
10+
11+
srng = at.random.RandomStream(seed=0)
12+
Y_rv = srng.normal(1, 2)
13+
14+
def logprob_fn(y):
15+
logprob = joint_logprob({Y_rv: y})
16+
return logprob
17+
18+
# Build the transition kernel
19+
kernel = nuts.new_kernel(srng, logprob_fn)
20+
21+
# Compile a function that updates the chain
22+
y_vv = Y_rv.clone()
23+
initial_state = nuts.new_state(y_vv, logprob_fn)
24+
25+
step_size = at.as_tensor(1e-2)
26+
inverse_mass_matrix = at.as_tensor(1.0)
27+
(
28+
next_state,
29+
potential_energy,
30+
potential_energy_grad,
31+
acceptance_prob,
32+
num_doublings,
33+
is_turning,
34+
is_diverging,
35+
), updates = kernel(*initial_state, step_size, inverse_mass_matrix)
36+
37+
next_step_fn = aesara.function([y_vv], next_state, updates=updates, mode="NUMBA")
38+
39+
# TODO: Assert something
40+
next_step_fn(Y_rv.eval())

0 commit comments

Comments
 (0)