Skip to content

Commit 6f95e04

Browse files
jackhansomaseyboldt
authored andcommitted
Adds a check in sample for start kwarg shapes (#2462)
* Adds test for shape of start argument * Adds case for iterable of start arguments and case when there is no shape * addresses comments to sample start check PR * adds space between string literals * adds return to _check_start_shape * adds unit tests for _check_start_shape
1 parent fcb23f9 commit 6f95e04

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

pymc3/sampling.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import defaultdict
1+
from collections import defaultdict, Sequence
22

33
from joblib import Parallel, delayed
44
from numpy.random import randint, seed
@@ -144,7 +144,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
144144
n_init : int
145145
Number of iterations of initializer
146146
If 'ADVI', number of iterations, if 'nuts', number of draws.
147-
start : dict
147+
start : dict, or array of dict
148148
Starting point in parameter space (or partial point)
149149
Defaults to trace.point(-1)) if there is a trace provided and
150150
model.test_point if not (defaults to empty dict).
@@ -227,6 +227,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
227227
"""
228228
model = modelcontext(model)
229229

230+
if start is not None:
231+
_check_start_shape(model, start)
232+
230233
draws += tune
231234

232235
if nuts_kwargs is not None:
@@ -280,6 +283,38 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
280283
return sample_func(**sample_args)[discard:]
281284

282285

286+
def _check_start_shape(model, start):
287+
e = ''
288+
if isinstance(start, (Sequence, np.ndarray)):
289+
# to deal with iterable start argument
290+
for start_iter in start:
291+
_check_start_shape(model, start_iter)
292+
return
293+
elif not isinstance(start, dict):
294+
raise TypeError("start argument must be a dict "
295+
"or an array-like of dicts")
296+
for var in model.vars:
297+
if var.name in start.keys():
298+
var_shape = var.shape.tag.test_value
299+
start_var_shape = np.shape(start[var.name])
300+
if start_var_shape:
301+
if not np.array_equal(var_shape, start_var_shape):
302+
e += "\nExpected shape {} for var '{}', got: {}".format(
303+
tuple(var_shape), var.name, start_var_shape
304+
)
305+
# if start var has no shape
306+
else:
307+
# if model var has a specified shape
308+
if var_shape:
309+
e += "\nExpected shape {} for var " \
310+
"'{}', got scalar {}".format(
311+
tuple(var_shape), var.name, start[var.name]
312+
)
313+
314+
if e != '':
315+
raise ValueError("Bad shape for start argument:{}".format(e))
316+
317+
283318
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
284319
progressbar=True, model=None, random_seed=-1, live_plot=False,
285320
live_plot_kwargs=None, **kwargs):

pymc3/tests/test_sampling.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,27 @@ def test_sample_tune_len(self):
110110
trace = pm.sample(draws=100, tune=50, njobs=4)
111111
assert len(trace) == 100
112112

113+
@pytest.mark.parametrize(
114+
'start, error', [
115+
([1, 2], TypeError),
116+
({'x': 1}, ValueError),
117+
({'x': [1, 2, 3]}, ValueError),
118+
({'x': np.array([[1, 1], [1, 1]])}, ValueError)
119+
]
120+
)
121+
def test_sample_start_bad_shape(self, start, error):
122+
with pytest.raises(error):
123+
pm.sampling._check_start_shape(self.model, start)
124+
125+
@pytest.mark.parametrize(
126+
'start', [
127+
{'x': np.array([1, 1])},
128+
[{'x': [10, 10]}, {'x': [-10, -10]}]
129+
]
130+
)
131+
def test_sample_start_good_shape(self, start):
132+
pm.sampling._check_start_shape(self.model, start)
133+
113134

114135
def test_empty_model():
115136
with pm.Model():

0 commit comments

Comments
 (0)