Skip to content

Commit e9ee56a

Browse files
colintwiecki
authored andcommitted
Only set random seed on purpose
1 parent 0ca78be commit e9ee56a

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

pymc3/sampling.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8080

8181

8282
def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None,
83-
progressbar=True, model=None, random_seed=None):
83+
progressbar=True, model=None, random_seed=-1):
8484
"""
8585
Draw a number of samples using the given step method.
8686
Multiple step methods supported via compound step method
@@ -116,8 +116,8 @@ def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None
116116
tune : int
117117
Number of iterations to tune, if applicable (defaults to None)
118118
progressbar : bool
119-
Whether or not to display a progress bar in the command line. The
120-
bar shows the percentage of completion, the sampling speed in
119+
Whether or not to display a progress bar in the command line. The
120+
bar shows the percentage of completion, the sampling speed in
121121
samples per second (SPS), and the estimated remaining time until
122122
completion ("expected time of arrival"; ETA).
123123
model : Model (optional if in `with` context)
@@ -156,7 +156,7 @@ def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None
156156

157157

158158
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
159-
progressbar=True, model=None, random_seed=None):
159+
progressbar=True, model=None, random_seed=-1):
160160
sampling = _iter_sample(draws, step, start, trace, chain,
161161
tune, model, random_seed)
162162
progress = progress_bar(draws)
@@ -170,7 +170,7 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
170170

171171

172172
def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
173-
model=None, random_seed=None):
173+
model=None, random_seed=-1):
174174
"""
175175
Generator that returns a trace on each iteration using the given
176176
step method. Multiple step methods supported via compound step
@@ -215,10 +215,11 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
215215

216216

217217
def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
218-
model=None, random_seed=None):
218+
model=None, random_seed=-1):
219219
model = modelcontext(model)
220220
draws = int(draws)
221-
seed(random_seed)
221+
if random_seed != -1:
222+
seed(random_seed)
222223
if draws < 1:
223224
raise ValueError('Argument `draws` should be above 0.')
224225

pymc3/tests/test_sampling.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ def setUp(self):
2424
super(TestSample, self).setUp()
2525
self.model, self.start, self.step, _ = simple_init()
2626

27+
def test_sample_does_not_set_seed(self):
28+
random_numbers = []
29+
for _ in range(2):
30+
np.random.seed(1)
31+
with self.model:
32+
pm.sample(1)
33+
random_numbers.append(np.random.random())
34+
self.assertEqual(random_numbers[0], random_numbers[1])
35+
2736
def test_sample(self):
2837
test_njobs = [1]
2938
with self.model:

0 commit comments

Comments
 (0)