6
6
import nutpie .compile_pymc
7
7
8
8
9
- def test_pymc_model ():
9
+ parameterize_backends = pytest .mark .parametrize (
10
+ "backend, gradient_backend" ,
11
+ [("numba" , None ), ("jax" , "pytensor" ), ("jax" , "jax" )],
12
+ )
13
+
14
+
15
+ @parameterize_backends
16
+ def test_pymc_model (backend , gradient_backend ):
10
17
with pm .Model () as model :
11
18
pm .Normal ("a" )
12
19
13
- compiled = nutpie .compile_pymc_model (model )
20
+ compiled = nutpie .compile_pymc_model (
21
+ model , backend = backend , gradient_backend = gradient_backend
22
+ )
14
23
trace = nutpie .sample (compiled , chains = 1 )
15
24
trace .posterior .a # noqa: B018
16
25
17
26
18
- def test_blocking ():
27
+ @parameterize_backends
28
+ def test_blocking (backend , gradient_backend ):
19
29
with pm .Model () as model :
20
30
pm .Normal ("a" )
21
31
22
- compiled = nutpie .compile_pymc_model (model )
32
+ compiled = nutpie .compile_pymc_model (
33
+ model , backend = backend , gradient_backend = gradient_backend
34
+ )
23
35
sampler = nutpie .sample (compiled , chains = 1 , blocking = False )
24
36
trace = sampler .wait ()
25
37
trace .posterior .a # noqa: B018
26
38
27
39
40
+ @parameterize_backends
28
41
@pytest .mark .timeout (2 )
29
- def test_wait_timeout ():
42
+ def test_wait_timeout (backend , gradient_backend ):
30
43
with pm .Model () as model :
31
44
pm .Normal ("a" , shape = 100_000 )
32
- compiled = nutpie .compile_pymc_model (model )
45
+ compiled = nutpie .compile_pymc_model (
46
+ model , backend = backend , gradient_backend = gradient_backend
47
+ )
33
48
sampler = nutpie .sample (compiled , chains = 1 , blocking = False )
34
49
with pytest .raises (TimeoutError ):
35
50
sampler .wait (timeout = 0.1 )
36
51
sampler .cancel ()
37
52
38
53
54
+ @parameterize_backends
39
55
@pytest .mark .timeout (2 )
40
- def test_pause ():
56
+ def test_pause (backend , gradient_backend ):
41
57
with pm .Model () as model :
42
58
pm .Normal ("a" , shape = 100_000 )
43
- compiled = nutpie .compile_pymc_model (model )
59
+ compiled = nutpie .compile_pymc_model (
60
+ model , backend = backend , gradient_backend = gradient_backend
61
+ )
44
62
sampler = nutpie .sample (compiled , chains = 1 , blocking = False )
45
63
sampler .pause ()
46
64
sampler .resume ()
47
65
sampler .cancel ()
48
66
49
67
50
- def test_pymc_model_with_coordinate ():
68
+ @parameterize_backends
69
+ def test_pymc_model_with_coordinate (backend , gradient_backend ):
51
70
with pm .Model () as model :
52
71
model .add_coord ("foo" , length = 5 )
53
72
pm .Normal ("a" , dims = "foo" )
54
73
55
- compiled = nutpie .compile_pymc_model (model )
74
+ compiled = nutpie .compile_pymc_model (
75
+ model , backend = backend , gradient_backend = gradient_backend
76
+ )
56
77
trace = nutpie .sample (compiled , chains = 1 )
57
78
trace .posterior .a # noqa: B018
58
79
59
80
60
- def test_pymc_model_store_extra ():
81
+ @parameterize_backends
82
+ def test_pymc_model_store_extra (backend , gradient_backend ):
61
83
with pm .Model () as model :
62
84
model .add_coord ("foo" , length = 5 )
63
85
pm .Normal ("a" , dims = "foo" )
64
86
65
- compiled = nutpie .compile_pymc_model (model )
87
+ compiled = nutpie .compile_pymc_model (
88
+ model , backend = backend , gradient_backend = gradient_backend
89
+ )
66
90
trace = nutpie .sample (
67
91
compiled ,
68
92
chains = 1 ,
@@ -78,33 +102,42 @@ def test_pymc_model_store_extra():
78
102
_ = trace .sample_stats .mass_matrix_inv
79
103
80
104
81
- def test_trafo ():
105
+ @parameterize_backends
106
+ def test_trafo (backend , gradient_backend ):
82
107
with pm .Model () as model :
83
108
pm .Uniform ("a" )
84
109
85
- compiled = nutpie .compile_pymc_model (model )
110
+ compiled = nutpie .compile_pymc_model (
111
+ model , backend = backend , gradient_backend = gradient_backend
112
+ )
86
113
trace = nutpie .sample (compiled , chains = 1 )
87
114
trace .posterior .a # noqa: B018
88
115
89
116
90
- def test_det ():
117
+ @parameterize_backends
118
+ def test_det (backend , gradient_backend ):
91
119
with pm .Model () as model :
92
120
a = pm .Uniform ("a" , shape = 2 )
93
121
pm .Deterministic ("b" , 2 * a )
94
122
95
- compiled = nutpie .compile_pymc_model (model )
123
+ compiled = nutpie .compile_pymc_model (
124
+ model , backend = backend , gradient_backend = gradient_backend
125
+ )
96
126
trace = nutpie .sample (compiled , chains = 1 )
97
127
assert trace .posterior .a .shape [- 1 ] == 2
98
128
assert trace .posterior .b .shape [- 1 ] == 2
99
129
100
130
101
- def test_pymc_model_shared ():
131
+ @parameterize_backends
132
+ def test_pymc_model_shared (backend , gradient_backend ):
102
133
with pm .Model () as model :
103
134
mu = pm .MutableData ("mu" , 0.1 )
104
135
sigma = pm .MutableData ("sigma" , np .ones (3 ))
105
136
pm .Normal ("a" , mu = mu , sigma = sigma , shape = 3 )
106
137
107
- compiled = nutpie .compile_pymc_model (model )
138
+ compiled = nutpie .compile_pymc_model (
139
+ model , backend = backend , gradient_backend = gradient_backend
140
+ )
108
141
trace = nutpie .sample (compiled , chains = 1 , seed = 1 )
109
142
np .testing .assert_allclose (trace .posterior .a .mean ().values , 0.1 , atol = 0.05 )
110
143
@@ -117,13 +150,16 @@ def test_pymc_model_shared():
117
150
nutpie .sample (compiled3 , chains = 1 )
118
151
119
152
120
- def test_missing ():
153
+ @parameterize_backends
154
+ def test_missing (backend , gradient_backend ):
121
155
with pm .Model (coords = {"obs" : range (4 )}) as model :
122
156
mu = pm .Normal ("mu" )
123
157
y = pm .Normal ("y" , mu , observed = [0 , - 1 , 1 , np .nan ], dims = "obs" )
124
158
pm .Deterministic ("y2" , 2 * y , dims = "obs" )
125
159
126
- compiled = nutpie .compile_pymc_model (model )
160
+ compiled = nutpie .compile_pymc_model (
161
+ model , backend = backend , gradient_backend = gradient_backend
162
+ )
127
163
tr = nutpie .sample (compiled , chains = 1 , seed = 1 )
128
164
print (tr .posterior )
129
165
assert hasattr (tr .posterior , "y_unobserved" )
0 commit comments