Skip to content

Commit 1962645

Browse files
ColCarrolltwiecki
authored andcommitted
Actually set random seed on calls to sample (#1393)
1 parent 30b0c50 commit 1962645

File tree

1 file changed

+60
-70
lines changed

1 file changed

+60
-70
lines changed

pymc3/tests/test_sampling.py

Lines changed: 60 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,109 +5,99 @@
55
import mock
66
import unittest
77

8-
import pymc3
9-
from pymc3 import sampling
10-
from pymc3.sampling import sample
8+
import pymc3 as pm
9+
import theano.tensor as tt
10+
from theano import shared
1111
from .models import simple_init
1212
from .helpers import SeededTest
1313

1414
# Test if multiprocessing is available
1515
import multiprocessing
1616
try:
1717
multiprocessing.Pool(2)
18-
test_parallel = False
1918
except:
20-
test_parallel = False
19+
pass
2120

2221

23-
def test_sample():
24-
model, start, step, _ = simple_init()
25-
test_njobs = [1]
26-
with model:
27-
for njobs in test_njobs:
28-
for n in [1, 10, 300]:
29-
yield sample, n, step, {}, None, njobs
22+
class TestSample(SeededTest):
23+
def setUp(self):
24+
super(TestSample, self).setUp()
25+
self.model, self.start, self.step, _ = simple_init()
3026

27+
def test_sample(self):
28+
test_njobs = [1]
29+
with self.model:
30+
for njobs in test_njobs:
31+
for steps in [1, 10, 300]:
32+
pm.sample(steps, self.step, {}, None, njobs=njobs, random_seed=self.random_seed)
3133

32-
def test_iter_sample():
33-
model, start, step, _ = simple_init()
34-
samps = sampling.iter_sample(5, step, start, model=model)
35-
for i, trace in enumerate(samps):
36-
assert i == len(trace) - 1, "Trace does not have correct length."
34+
def test_iter_sample(self):
35+
with self.model:
36+
samps = pm.sampling.iter_sample(5, self.step, self.start, random_seed=self.random_seed)
37+
for i, trace in enumerate(samps):
38+
self.assertEqual(i, len(trace) - 1, "Trace does not have correct length.")
3739

38-
39-
class TestParallelStart(SeededTest):
4040
def test_parallel_start(self):
41-
model, _, _, _ = simple_init()
42-
with model:
43-
tr = sample(5, njobs=2, start=[{'x': [10, 10]}, {'x': [-10, -10]}])
41+
with self.model:
42+
tr = pm.sample(5, njobs=2, start=[{'x': [10, 10]}, {'x': [-10, -10]}],
43+
random_seed=self.random_seed)
4444
self.assertGreater(tr.get_values('x', chains=0)[0][0], 0)
4545
self.assertLess(tr.get_values('x', chains=1)[0][0], 0)
4646

4747

48-
def test_soft_update_all_present():
49-
start = {'a': 1, 'b': 2}
50-
test_point = {'a': 3, 'b': 4}
51-
sampling._soft_update(start, test_point)
52-
assert start == {'a': 1, 'b': 2}
53-
48+
class SoftUpdate(SeededTest):
49+
def test_soft_update_all_present(self):
50+
start = {'a': 1, 'b': 2}
51+
test_point = {'a': 3, 'b': 4}
52+
pm.sampling._soft_update(start, test_point)
53+
self.assertDictEqual(start, {'a': 1, 'b': 2})
5454

55-
def test_soft_update_one_missing():
56-
start = {'a': 1, }
57-
test_point = {'a': 3, 'b': 4}
58-
sampling._soft_update(start, test_point)
59-
assert start == {'a': 1, 'b': 4}
55+
def test_soft_update_one_missing(self):
56+
start = {'a': 1, }
57+
test_point = {'a': 3, 'b': 4}
58+
pm.sampling._soft_update(start, test_point)
59+
self.assertDictEqual(start, {'a': 1, 'b': 4})
6060

61-
62-
def test_soft_update_empty():
63-
start = {}
64-
test_point = {'a': 3, 'b': 4}
65-
sampling._soft_update(start, test_point)
66-
assert start == test_point
61+
def test_soft_update_empty(self):
62+
start = {}
63+
test_point = {'a': 3, 'b': 4}
64+
pm.sampling._soft_update(start, test_point)
65+
self.assertDictEqual(start, test_point)
6766

6867

6968
class TestNamedSampling(SeededTest):
7069
def test_shared_named(self):
71-
from theano import shared
72-
import theano.tensor as tt
73-
7470
G_var = shared(value=np.atleast_2d(1.), broadcastable=(True, False),
7571
name="G")
7672

77-
with pymc3.Model():
78-
theta0 = pymc3.Normal('theta0', mu=np.atleast_2d(0),
79-
tau=np.atleast_2d(1e20), shape=(1, 1),
80-
testval=np.atleast_2d(0))
81-
theta = pymc3.Normal('theta', mu=tt.dot(G_var, theta0),
82-
tau=np.atleast_2d(1e20), shape=(1, 1))
83-
73+
with pm.Model():
74+
theta0 = pm.Normal('theta0', mu=np.atleast_2d(0),
75+
tau=np.atleast_2d(1e20), shape=(1, 1),
76+
testval=np.atleast_2d(0))
77+
theta = pm.Normal('theta', mu=tt.dot(G_var, theta0),
78+
tau=np.atleast_2d(1e20), shape=(1, 1))
8479
res = theta.random()
8580
assert np.isclose(res, 0.)
8681

8782
def test_shared_unnamed(self):
88-
from theano import shared
89-
import theano.tensor as tt
9083
G_var = shared(value=np.atleast_2d(1.), broadcastable=(True, False))
91-
with pymc3.Model():
92-
theta0 = pymc3.Normal('theta0', mu=np.atleast_2d(0),
93-
tau=np.atleast_2d(1e20), shape=(1, 1),
94-
testval=np.atleast_2d(0))
95-
theta = pymc3.Normal('theta', mu=tt.dot(G_var, theta0),
96-
tau=np.atleast_2d(1e20), shape=(1, 1))
97-
84+
with pm.Model():
85+
theta0 = pm.Normal('theta0', mu=np.atleast_2d(0),
86+
tau=np.atleast_2d(1e20), shape=(1, 1),
87+
testval=np.atleast_2d(0))
88+
theta = pm.Normal('theta', mu=tt.dot(G_var, theta0),
89+
tau=np.atleast_2d(1e20), shape=(1, 1))
9890
res = theta.random()
9991
assert np.isclose(res, 0.)
10092

10193
def test_constant_named(self):
102-
import theano.tensor as tt
103-
10494
G_var = tt.constant(np.atleast_2d(1.), name="G")
105-
with pymc3.Model():
106-
theta0 = pymc3.Normal('theta0', mu=np.atleast_2d(0),
107-
tau=np.atleast_2d(1e20), shape=(1, 1),
108-
testval=np.atleast_2d(0))
109-
theta = pymc3.Normal('theta', mu=tt.dot(G_var, theta0),
110-
tau=np.atleast_2d(1e20), shape=(1, 1))
95+
with pm.Model():
96+
theta0 = pm.Normal('theta0', mu=np.atleast_2d(0),
97+
tau=np.atleast_2d(1e20), shape=(1, 1),
98+
testval=np.atleast_2d(0))
99+
theta = pm.Normal('theta', mu=tt.dot(G_var, theta0),
100+
tau=np.atleast_2d(1e20), shape=(1, 1))
111101

112102
res = theta.random()
113103
assert np.isclose(res, 0.)
@@ -116,22 +106,22 @@ def test_constant_named(self):
116106
class TestChooseBackend(unittest.TestCase):
117107
def test_choose_backend_none(self):
118108
with mock.patch('pymc3.sampling.NDArray') as nd:
119-
sampling._choose_backend(None, 'chain')
109+
pm.sampling._choose_backend(None, 'chain')
120110
self.assertTrue(nd.called)
121111

122112
def test_choose_backend_list_of_variables(self):
123113
with mock.patch('pymc3.sampling.NDArray') as nd:
124-
sampling._choose_backend(['var1', 'var2'], 'chain')
114+
pm.sampling._choose_backend(['var1', 'var2'], 'chain')
125115
nd.assert_called_with(vars=['var1', 'var2'])
126116

127117
def test_choose_backend_invalid(self):
128118
self.assertRaises(ValueError,
129-
sampling._choose_backend,
119+
pm.sampling._choose_backend,
130120
'invalid', 'chain')
131121

132122
def test_choose_backend_shortcut(self):
133123
backend = mock.Mock()
134124
shortcuts = {'test_backend': {'backend': backend,
135125
'name': None}}
136-
sampling._choose_backend('test_backend', 'chain', shortcuts=shortcuts)
126+
pm.sampling._choose_backend('test_backend', 'chain', shortcuts=shortcuts)
137127
self.assertTrue(backend.called)

0 commit comments

Comments
 (0)