1616import numpy .testing as npt
1717import pytest
1818
19- import pymc as pm
19+ from pymc import Data , Deterministic , HalfNormal , Model , Normal , sample
2020
2121
2222@pytest .mark .parametrize ("nuts_sampler" , ["pymc" , "nutpie" , "blackjax" , "numpyro" ])
2323def test_external_nuts_sampler (recwarn , nuts_sampler ):
2424 if nuts_sampler != "pymc" :
2525 pytest .importorskip (nuts_sampler )
2626
27- with pm . Model ():
28- x = pm . Normal ("x" , 100 , 5 )
29- y = pm . Data ("y" , [1 , 2 , 3 , 4 ])
30- pm . Data ("z" , [100 , 190 , 310 , 405 ])
27+ with Model ():
28+ x = Normal ("x" , 100 , 5 )
29+ y = Data ("y" , [1 , 2 , 3 , 4 ])
30+ Data ("z" , [100 , 190 , 310 , 405 ])
3131
32- pm . Normal ("L" , mu = x , sigma = 0.1 , observed = y )
32+ Normal ("L" , mu = x , sigma = 0.1 , observed = y )
3333
3434 kwargs = {
3535 "nuts_sampler" : nuts_sampler ,
@@ -41,12 +41,12 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
4141 "initvals" : {"x" : 0.0 },
4242 }
4343
44- idata1 = pm . sample (** kwargs )
45- idata2 = pm . sample (** kwargs )
44+ idata1 = sample (** kwargs )
45+ idata2 = sample (** kwargs )
4646
4747 reference_kwargs = kwargs .copy ()
4848 reference_kwargs ["nuts_sampler" ] = "pymc"
49- idata_reference = pm . sample (** reference_kwargs )
49+ idata_reference = sample (** reference_kwargs )
5050
5151 warns = {
5252 (warn .category , warn .message .args [0 ])
@@ -75,9 +75,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
7575
7676
7777def test_step_args ():
78- with pm . Model () as model :
79- a = pm . Normal ("a" )
80- idata = pm . sample (
78+ with Model () as model :
79+ a = Normal ("a" )
80+ idata = sample (
8181 nuts_sampler = "numpyro" ,
8282 target_accept = 0.5 ,
8383 nuts = {"max_treedepth" : 10 },
@@ -108,17 +108,17 @@ def test_sample_var_names(nuts_sampler):
108108 coords = {"group" : group_values }
109109
110110 # Create model
111- with pm . Model (coords = coords ) as model :
112- b_group = pm . Normal ("b_group" , dims = "group" )
113- b_x = pm . Normal ("b_x" )
114- mu = pm . Deterministic ("mu" , b_group [group_idx ] + b_x * x )
115- sigma = pm . HalfNormal ("sigma" )
116- pm . Normal ("y" , mu = mu , sigma = sigma , observed = y )
111+ with Model (coords = coords ) as model :
112+ b_group = Normal ("b_group" , dims = "group" )
113+ b_x = Normal ("b_x" )
114+ mu = Deterministic ("mu" , b_group [group_idx ] + b_x * x )
115+ sigma = HalfNormal ("sigma" )
116+ Normal ("y" , mu = mu , sigma = sigma , observed = y )
117117
118118 # Sample with and without var_names, but always with the same seed
119119 with model :
120- idata_1 = pm . sample (tune = 100 , draws = 100 , random_seed = seed , ** kwargs )
121- idata_2 = pm . sample (
120+ idata_1 = sample (tune = 100 , draws = 100 , random_seed = seed , ** kwargs )
121+ idata_2 = sample (
122122 tune = 100 , draws = 100 , var_names = ["b_group" , "b_x" , "sigma" ], random_seed = seed , ** kwargs
123123 )
124124
0 commit comments