1616import numpy .testing as npt
1717import pytest
1818
19- from pymc import Data , Model , Normal , sample
19+ from pymc import Data , Model , Normal , modelcontext , sample
2020
2121
22- @pytest .mark .parametrize ("nuts_sampler" , ["pymc" , "nutpie" , "blackjax" , "numpyro" ])
23- def test_external_nuts_sampler (recwarn , nuts_sampler ):
24- if nuts_sampler != "pymc" :
25- pytest .importorskip (nuts_sampler )
26-
27- with Model ():
28- x = Normal ("x" , 100 , 5 )
29- y = Data ("y" , [1 , 2 , 3 , 4 ])
30- Data ("z" , [100 , 190 , 310 , 405 ])
31-
32- Normal ("L" , mu = x , sigma = 0.1 , observed = y )
33-
34- kwargs = {
35- "nuts_sampler" : nuts_sampler ,
36- "random_seed" : 123 ,
37- "chains" : 2 ,
38- "tune" : 500 ,
39- "draws" : 500 ,
40- "progressbar" : False ,
41- "initvals" : {"x" : 0.0 },
42- }
43-
44- idata1 = sample (** kwargs )
45- idata2 = sample (** kwargs )
22+ def check_external_sampler_output (warns , idata1 , idata2 , sample_kwargs ):
23+ nuts_sampler = sample_kwargs ["nuts_sampler" ]
24+ reference_kwargs = sample_kwargs .copy ()
25+ reference_kwargs ["nuts_sampler" ] = "pymc"
4626
47- reference_kwargs = kwargs .copy ()
48- reference_kwargs ["nuts_sampler" ] = "pymc"
27+ with modelcontext (None ):
4928 idata_reference = sample (** reference_kwargs )
5029
51- warns = {
52- (warn .category , warn .message .args [0 ])
53- for warn in recwarn
54- if warn .category not in (FutureWarning , DeprecationWarning , RuntimeWarning )
55- }
5630 expected = set ()
57- if nuts_sampler == "nutpie" :
31+ if nuts_sampler . startswith ( "nutpie" ) :
5832 expected .add (
5933 (
6034 UserWarning ,
@@ -74,7 +48,83 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
7448 assert idata_reference .posterior .attrs .keys () == idata1 .posterior .attrs .keys ()
7549
7650
51+ @pytest .fixture
52+ def pymc_model ():
53+ with Model () as m :
54+ x = Normal ("x" , 100 , 5 )
55+ y = Data ("y" , [1 , 2 , 3 , 4 ])
56+ Data ("z" , [100 , 190 , 310 , 405 ])
57+
58+ Normal ("L" , mu = x , sigma = 0.1 , observed = y )
59+
60+ return m
61+
62+
63+ @pytest .mark .parametrize ("nuts_sampler" , ["pymc" , "nutpie" , "blackjax" , "numpyro" ])
64+ def test_external_nuts_sampler (pymc_model , recwarn , nuts_sampler ):
65+ if nuts_sampler != "pymc" :
66+ pytest .importorskip (nuts_sampler )
67+
68+ sample_kwargs = dict (
69+ nuts_sampler = nuts_sampler ,
70+ random_seed = 123 ,
71+ chains = 2 ,
72+ tune = 500 ,
73+ draws = 500 ,
74+ progressbar = False ,
75+ initvals = {"x" : 0.0 },
76+ )
77+
78+ with pymc_model :
79+ idata1 = sample (** sample_kwargs )
80+ idata2 = sample (** sample_kwargs )
81+
82+ warns = {
83+ (warn .category , warn .message .args [0 ])
84+ for warn in recwarn
85+ if warn .category not in (FutureWarning , DeprecationWarning , RuntimeWarning )
86+ }
87+
88+ check_external_sampler_output (warns , idata1 , idata2 , sample_kwargs )
89+
90+
91+ @pytest .mark .parametrize ("backend" , ["numba" , "jax" ], ids = ["numba" , "jax" ])
92+ def test_numba_backend_options (pymc_model , recwarn , backend ):
93+ pytest .importorskip ("nutpie" )
94+ pytest .importorskip (backend )
95+
96+ sample_kwargs = dict (
97+ nuts_sampler = f"nutpie[{ backend } ]" ,
98+ random_seed = 123 ,
99+ chains = 2 ,
100+ tune = 500 ,
101+ draws = 500 ,
102+ progressbar = False ,
103+ initvals = {"x" : 0.0 },
104+ )
105+
106+ with pymc_model :
107+ idata1 = sample (** sample_kwargs )
108+ idata2 = sample (** sample_kwargs )
109+
110+ warns = {
111+ (warn .category , warn .message .args [0 ])
112+ for warn in recwarn
113+ if warn .category not in (FutureWarning , DeprecationWarning , RuntimeWarning )
114+ }
115+
116+ check_external_sampler_output (warns , idata1 , idata2 , sample_kwargs )
117+
118+
119+ def test_invalid_nutpie_backend_raises (pymc_model ):
120+ with pytest .raises (ValueError , match = 'Expected one of "numba" or "jax"; found "invalid"' ):
121+ with pymc_model :
122+ sample (nuts_sampler = "nutpie[invalid]" , random_seed = 123 , chains = 2 , tune = 500 , draws = 500 )
123+
124+
77125def test_step_args ():
126+ pytest .importorskip ("numpyro" )
127+
78128 with Model () as model :
79129 a = Normal ("a" )
80130 idata = sample (
0 commit comments