Skip to content

Commit bfd35fb

Browse files
committed
Add a simple test for numba compilation
1 parent dde2d83 commit bfd35fb

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/test_numba.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Test file created for the sole purpose of tracking the status of Numba compilation"""
2+
import aesara
3+
import aesara.tensor as at
4+
from aeppl import joint_logprob
5+
import aehmc.nuts as nuts
6+
7+
8+
def test_sample_with_numba():
9+
10+
srng = at.random.RandomStream(seed=0)
11+
Y_rv = srng.normal(1, 2)
12+
13+
def logprob_fn(y):
14+
logprob = joint_logprob({Y_rv: y})
15+
return logprob
16+
17+
# Build the transition kernel
18+
kernel = nuts.new_kernel(srng, logprob_fn)
19+
20+
# Compile a function that updates the chain
21+
y_vv = Y_rv.clone()
22+
initial_state = nuts.new_state(y_vv, logprob_fn)
23+
24+
step_size = at.as_tensor(1e-2)
25+
inverse_mass_matrix = at.as_tensor(1.0)
26+
(
27+
next_state,
28+
potential_energy,
29+
potential_energy_grad,
30+
acceptance_prob,
31+
num_doublings,
32+
is_turning,
33+
is_diverging,
34+
), updates = kernel(*initial_state, step_size, inverse_mass_matrix)
35+
36+
next_step_fn = aesara.function([y_vv], next_state, updates=updates, mode='NUMBA')

0 commit comments

Comments
 (0)