File tree Expand file tree Collapse file tree 1 file changed +40
-0
lines changed Expand file tree Collapse file tree 1 file changed +40
-0
lines changed Original file line number Diff line number Diff line change
1
+ """Test file created for the sole purpose of tracking the status of Numba compilation"""
2
+ import aesara
3
+ import aesara .tensor as at
4
+ from aeppl import joint_logprob
5
+
6
+ import aehmc .nuts as nuts
7
+
8
+
9
+ def test_sample_with_numba ():
10
+
11
+ srng = at .random .RandomStream (seed = 0 )
12
+ Y_rv = srng .normal (1 , 2 )
13
+
14
+ def logprob_fn (y ):
15
+ logprob = joint_logprob ({Y_rv : y })
16
+ return logprob
17
+
18
+ # Build the transition kernel
19
+ kernel = nuts .new_kernel (srng , logprob_fn )
20
+
21
+ # Compile a function that updates the chain
22
+ y_vv = Y_rv .clone ()
23
+ initial_state = nuts .new_state (y_vv , logprob_fn )
24
+
25
+ step_size = at .as_tensor (1e-2 )
26
+ inverse_mass_matrix = at .as_tensor (1.0 )
27
+ (
28
+ next_state ,
29
+ potential_energy ,
30
+ potential_energy_grad ,
31
+ acceptance_prob ,
32
+ num_doublings ,
33
+ is_turning ,
34
+ is_diverging ,
35
+ ), updates = kernel (* initial_state , step_size , inverse_mass_matrix )
36
+
37
+ next_step_fn = aesara .function ([y_vv ], next_state , updates = updates , mode = "NUMBA" )
38
+
39
+ # TODO: Assert something
40
+ next_step_fn (Y_rv .eval ())
You can’t perform that action at this time.
0 commit comments