5
5
import mock
6
6
import unittest
7
7
8
- import pymc3
9
- from pymc3 import sampling
10
- from pymc3 . sampling import sample
8
+ import pymc3 as pm
9
+ import theano . tensor as tt
10
+ from theano import shared
11
11
from .models import simple_init
12
12
from .helpers import SeededTest
13
13
14
14
# Test if multiprocessing is available
15
15
import multiprocessing
16
16
try :
17
17
multiprocessing .Pool (2 )
18
- test_parallel = False
19
18
except :
20
- test_parallel = False
19
+ pass
21
20
22
21
23
- def test_sample ():
24
- model , start , step , _ = simple_init ()
25
- test_njobs = [1 ]
26
- with model :
27
- for njobs in test_njobs :
28
- for n in [1 , 10 , 300 ]:
29
- yield sample , n , step , {}, None , njobs
22
+ class TestSample (SeededTest ):
23
+ def setUp (self ):
24
+ super (TestSample , self ).setUp ()
25
+ self .model , self .start , self .step , _ = simple_init ()
30
26
27
+ def test_sample (self ):
28
+ test_njobs = [1 ]
29
+ with self .model :
30
+ for njobs in test_njobs :
31
+ for steps in [1 , 10 , 300 ]:
32
+ pm .sample (steps , self .step , {}, None , njobs = njobs , random_seed = self .random_seed )
31
33
32
- def test_iter_sample ():
33
- model , start , step , _ = simple_init ()
34
- samps = sampling .iter_sample (5 , step , start , model = model )
35
- for i , trace in enumerate (samps ):
36
- assert i == len (trace ) - 1 , "Trace does not have correct length."
34
+ def test_iter_sample (self ):
35
+ with self . model :
36
+ samps = pm . sampling .iter_sample (5 , self . step , self . start , random_seed = self . random_seed )
37
+ for i , trace in enumerate (samps ):
38
+ self . assertEqual ( i , len (trace ) - 1 , "Trace does not have correct length." )
37
39
38
-
39
- class TestParallelStart (SeededTest ):
40
40
def test_parallel_start (self ):
41
- model , _ , _ , _ = simple_init ()
42
- with model :
43
- tr = sample ( 5 , njobs = 2 , start = [{ 'x' : [ 10 , 10 ]}, { 'x' : [ - 10 , - 10 ]}] )
41
+ with self . model :
42
+ tr = pm . sample ( 5 , njobs = 2 , start = [{ 'x' : [ 10 , 10 ]}, { 'x' : [ - 10 , - 10 ]}],
43
+ random_seed = self . random_seed )
44
44
self .assertGreater (tr .get_values ('x' , chains = 0 )[0 ][0 ], 0 )
45
45
self .assertLess (tr .get_values ('x' , chains = 1 )[0 ][0 ], 0 )
46
46
47
47
48
- def test_soft_update_all_present ( ):
49
- start = { 'a' : 1 , 'b' : 2 }
50
- test_point = {'a' : 3 , 'b' : 4 }
51
- sampling . _soft_update ( start , test_point )
52
- assert start == { 'a' : 1 , 'b' : 2 }
53
-
48
+ class SoftUpdate ( SeededTest ):
49
+ def test_soft_update_all_present ( self ):
50
+ start = {'a' : 1 , 'b' : 2 }
51
+ test_point = { 'a' : 3 , 'b' : 4 }
52
+ pm . sampling . _soft_update ( start , test_point )
53
+ self . assertDictEqual ( start , { 'a' : 1 , 'b' : 2 })
54
54
55
- def test_soft_update_one_missing ():
56
- start = {'a' : 1 , }
57
- test_point = {'a' : 3 , 'b' : 4 }
58
- sampling ._soft_update (start , test_point )
59
- assert start == {'a' : 1 , 'b' : 4 }
55
+ def test_soft_update_one_missing (self ):
56
+ start = {'a' : 1 , }
57
+ test_point = {'a' : 3 , 'b' : 4 }
58
+ pm . sampling ._soft_update (start , test_point )
59
+ self . assertDictEqual ( start , {'a' : 1 , 'b' : 4 })
60
60
61
-
62
- def test_soft_update_empty ():
63
- start = {}
64
- test_point = {'a' : 3 , 'b' : 4 }
65
- sampling ._soft_update (start , test_point )
66
- assert start == test_point
61
+ def test_soft_update_empty (self ):
62
+ start = {}
63
+ test_point = {'a' : 3 , 'b' : 4 }
64
+ pm .sampling ._soft_update (start , test_point )
65
+ self .assertDictEqual (start , test_point )
67
66
68
67
69
68
class TestNamedSampling (SeededTest ):
70
69
def test_shared_named (self ):
71
- from theano import shared
72
- import theano .tensor as tt
73
-
74
70
G_var = shared (value = np .atleast_2d (1. ), broadcastable = (True , False ),
75
71
name = "G" )
76
72
77
- with pymc3 .Model ():
78
- theta0 = pymc3 .Normal ('theta0' , mu = np .atleast_2d (0 ),
79
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
80
- testval = np .atleast_2d (0 ))
81
- theta = pymc3 .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
82
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
83
-
73
+ with pm .Model ():
74
+ theta0 = pm .Normal ('theta0' , mu = np .atleast_2d (0 ),
75
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
76
+ testval = np .atleast_2d (0 ))
77
+ theta = pm .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
78
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
84
79
res = theta .random ()
85
80
assert np .isclose (res , 0. )
86
81
87
82
def test_shared_unnamed (self ):
88
- from theano import shared
89
- import theano .tensor as tt
90
83
G_var = shared (value = np .atleast_2d (1. ), broadcastable = (True , False ))
91
- with pymc3 .Model ():
92
- theta0 = pymc3 .Normal ('theta0' , mu = np .atleast_2d (0 ),
93
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
94
- testval = np .atleast_2d (0 ))
95
- theta = pymc3 .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
96
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
97
-
84
+ with pm .Model ():
85
+ theta0 = pm .Normal ('theta0' , mu = np .atleast_2d (0 ),
86
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
87
+ testval = np .atleast_2d (0 ))
88
+ theta = pm .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
89
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
98
90
res = theta .random ()
99
91
assert np .isclose (res , 0. )
100
92
101
93
def test_constant_named (self ):
102
- import theano .tensor as tt
103
-
104
94
G_var = tt .constant (np .atleast_2d (1. ), name = "G" )
105
- with pymc3 .Model ():
106
- theta0 = pymc3 .Normal ('theta0' , mu = np .atleast_2d (0 ),
107
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
108
- testval = np .atleast_2d (0 ))
109
- theta = pymc3 .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
110
- tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
95
+ with pm .Model ():
96
+ theta0 = pm .Normal ('theta0' , mu = np .atleast_2d (0 ),
97
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ),
98
+ testval = np .atleast_2d (0 ))
99
+ theta = pm .Normal ('theta' , mu = tt .dot (G_var , theta0 ),
100
+ tau = np .atleast_2d (1e20 ), shape = (1 , 1 ))
111
101
112
102
res = theta .random ()
113
103
assert np .isclose (res , 0. )
@@ -116,22 +106,22 @@ def test_constant_named(self):
116
106
class TestChooseBackend (unittest .TestCase ):
117
107
def test_choose_backend_none (self ):
118
108
with mock .patch ('pymc3.sampling.NDArray' ) as nd :
119
- sampling ._choose_backend (None , 'chain' )
109
+ pm . sampling ._choose_backend (None , 'chain' )
120
110
self .assertTrue (nd .called )
121
111
122
112
def test_choose_backend_list_of_variables (self ):
123
113
with mock .patch ('pymc3.sampling.NDArray' ) as nd :
124
- sampling ._choose_backend (['var1' , 'var2' ], 'chain' )
114
+ pm . sampling ._choose_backend (['var1' , 'var2' ], 'chain' )
125
115
nd .assert_called_with (vars = ['var1' , 'var2' ])
126
116
127
117
def test_choose_backend_invalid (self ):
128
118
self .assertRaises (ValueError ,
129
- sampling ._choose_backend ,
119
+ pm . sampling ._choose_backend ,
130
120
'invalid' , 'chain' )
131
121
132
122
def test_choose_backend_shortcut (self ):
133
123
backend = mock .Mock ()
134
124
shortcuts = {'test_backend' : {'backend' : backend ,
135
125
'name' : None }}
136
- sampling ._choose_backend ('test_backend' , 'chain' , shortcuts = shortcuts )
126
+ pm . sampling ._choose_backend ('test_backend' , 'chain' , shortcuts = shortcuts )
137
127
self .assertTrue (backend .called )
0 commit comments