Skip to content

Commit 6fd230f

Browse files
TStescoaloctavodia
authored andcommitted
Changes sample_smc to call _initial_population and generate starting points only if starting points are not provided in start param. Also adding reference values to test_step.py for smc with provided start points. (#3062)
1 parent 625435e commit 6fd230f

File tree

2 files changed

+43
-48
lines changed

2 files changed

+43
-48
lines changed

pymc3/step_methods/smc.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,13 +460,6 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
460460
if not (chains / float(cores)).is_integer():
461461
raise TypeError('chains / cores has to be a whole number!')
462462

463-
if start is not None:
464-
if len(start) != chains:
465-
raise TypeError('Argument `start` should have dicts equal the '
466-
'number of chains (`chains`)')
467-
else:
468-
step.population = start
469-
470463
if not any(step.likelihood_name in var.name for var in model.deterministics):
471464
raise TypeError('Model (deterministic) variables need to contain a variable {} as defined '
472465
'in `step`.'.format(step.likelihood_name))
@@ -490,7 +483,15 @@ def sample_smc(samples=1000, chains=100, step=None, start=None, homepath=None, s
490483

491484
step.resampling_indexes = np.arange(chains)
492485
step.proposal_samples_array = step.proposal_dist(chains)
493-
step.population = _initial_population(samples, chains, model, step.vars)
486+
487+
if start is not None:
488+
if len(start) != chains:
489+
raise TypeError('Argument `start` should have dicts equal the '
490+
'number of chains (`chains`)')
491+
else:
492+
step.population = start
493+
else:
494+
step.population = _initial_population(samples, chains, model, step.vars)
494495

495496
with model:
496497
while step.beta < 1:

pymc3/tests/test_step.py

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -124,46 +124,40 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
124124
-2.24238542e+00, -1.01648100e+00, -1.01648100e+00, -7.60912865e-01,
125125
1.44384812e+00, 2.07355127e+00, 1.91390340e+00, 1.66559696e+00]),
126126
smc.SMC: np.array(
127-
[ 0.94245927, 0.04320349, 0.16616453, -0.42667441, -0.49780471,
128-
0.65384837, -0.25387836, 0.38232654, 0.62490342, -0.21777828,
129-
-0.70756665, 0.9310788 , -0.03941721, -1.20854932, 0.39442244,
130-
0.24306076, -0.98310433, 2.2503327 , 0.54090823, 0.51685018,
131-
-1.32968792, 0.02445827, -0.62052594, -0.28014643, 0.75977904,
132-
-1.20233021, -1.80432242, -0.31547627, -0.33392375, -1.34380736,
133-
1.44597486, -0.15871648, -0.20727832, 0.99115736, 0.3445085 ,
134-
-0.89909578, -0.36983042, 0.16734659, 0.13228431, -0.16786514,
135-
-0.36268027, 0.13369353, -1.28444377, 1.2644179 , -0.47877275,
136-
-0.4411035 , 0.35735115, -1.27425973, -0.43213873, 0.70698702,
137-
-0.7805279 , -1.67705636, -0.10661104, -0.59947856, 0.02693047,
138-
-1.09062222, -0.73592286, -1.56822784, 0.97077952, -0.02149393,
139-
-0.26597767, -0.38710878, -0.09971606, -0.52523725, 1.64000249,
140-
-0.1287883 , 0.09555045, 0.04258323, -0.16771237, 0.79324588,
141-
-0.4439878 , -0.00328163, 0.01267578, 0.31817545, -2.48389205,
142-
-0.43794095, -0.18922707, 0.0042956 , 0.29387263, 0.66119344,
143-
-0.98277349, 0.4039511 , 0.13542066, -0.78467059, -0.24334413,
144-
-0.62519786, -0.79586084, -0.06190844, 0.11355637, 0.66110093,
145-
-2.10383759, 0.48608459, -0.47993295, 0.46791254, 2.01963317,
146-
0.12975299, 1.71604836, -0.09413096, 0.30744711, 0.15079852,
147-
0.31349994, 0.26575959, 0.763656 , -1.81526952, -0.22984443,
148-
1.10531065, 0.26065936, -0.22274362, -0.20853456, 0.32741584,
149-
0.08521911, -1.53866503, 0.28501159, -0.39016642, 0.09505455,
150-
-0.72955337, 1.46268494, 0.56252715, -1.63048738, 1.45718808,
151-
-0.01141763, 0.65826932, 1.8723026 , 0.90744555, 1.40586042,
152-
1.58765986, 0.06792152, -0.71397153, 0.22718523, -1.90281392,
153-
0.58708453, -0.77195137, -0.56979511, -0.6543881 , -1.3711677 ,
154-
-1.72706576, -0.41484231, 0.17460229, 0.74160523, 0.10991525,
155-
0.50297247, 1.04762338, -0.69148618, 1.23291629, -0.49797445,
156-
-0.24914585, 1.44290113, -0.23288806, -1.15495976, 0.63230627,
157-
-1.06229509, 0.18047975, -1.23701009, 0.10994274, -0.81730888,
158-
0.01827404, -0.22824212, -0.76809243, -1.36315643, 0.76097799,
159-
1.51091188, 0.46931765, 1.27261922, 0.98191306, 0.80721561,
160-
1.12844558, 1.86799414, 0.29913787, -1.49977561, 0.7551137 ,
161-
-1.0622067 , -0.46200335, -0.10271276, -0.63924651, 1.56074961,
162-
-0.53611858, -0.23229769, -0.74455411, -2.41567262, -0.96658159,
163-
-0.08795562, 0.08532369, -1.56005584, -0.99356212, 0.32678269,
164-
-0.87012306, 0.83897514, 0.9799229 , -1.27565975, -0.25761179,
165-
0.34968085, -0.95045095, 0.95192797, -1.5101461 , 0.04042998,
166-
-0.91145107, -0.91700215, 0.0825614 , 0.59658604, 0.64933802]),
127+
[ 0.61562138, -0.56082978, -0.89760381, 1.47368457, 0.33300527, 0.85567605,
128+
-1.33503519, -1.47996682, -0.3725601, 0.75713321, 1.81055917, 0.39193534,
129+
0.10083821, 0.55569412, -0.65879812, -0.61545061, -2.65522875, 0.93801687,
130+
2.40499211, -0.63022535, 0.09565784, -1.00650846, 1.65901231, 0.18429996,
131+
1.64642521, 0.5589963, -0.40452525, -0.9402324, 0.53813986, 0.55785946,
132+
1.22966132, 0.2782562, -0.81254158, -0.08076293, -0.29136329, 0.62914226,
133+
0.16049388, -0.06386387, 1.8103961, -0.98444811, -0.36333739, 0.88703339,
134+
-0.08482673, -0.23224262, -0.11348807, 1.09401682, -0.58594449, -0.12728503,
135+
-0.82408778, -1.82770764, -2.28859404, -0.51814943, -1.53652851, 0.66313366,
136+
1.61666698, 1.41284768, -0.05129251, 0.96166282, 1.00446145, -0.86380886,
137+
-1.13410885, -0.48311125, -1.25446622, -0.48452779, -0.84647195, -0.43201729,
138+
-1.22186151, 1.18698485, 0.33434434, -0.40650775, 0.47740064, 0.96943022,
139+
1.15534028, -0.86220564, -0.26049285, -1.17489183, 0.66796904, -1.68920203,
140+
-0.96308602, -1.73542483, -0.84744376, 0.91864514, -0.02724505, 0.16143404,
141+
0.65747707, -1.49655923, -0.32010575, 1.20255401, 0.1203948, -1.30017822,
142+
1.55895643, -0.74799042, -1.5938872, 0.69297144, -1.32932843, -0.16886992,
143+
-1.01437109, 0.32476589, 1.02509164, 0.31274278, -0.7908909, 1.18439217,
144+
-0.96132492, -0.4934065, 0.71438293, 0.09829997, 1.81936381, 0.47941016,
145+
0.3717936, 0.14339747, 1.24288736, 0.92520773, 0.69025067, 0.96618094,
146+
0.69085402, -1.12128175, 0.11228076, 0.7711306, 0.12859226, 0.65792466,
147+
-0.07422313, 1.74736739, 0.24120104, 0.74946338, 0.66260928, -0.34070258,
148+
1.09875434, -0.4159233, -0.01607339, 1.20921296, -0.29176047, 0.47367867,
149+
-1.45788116, -0.40198772, 0.44502909, 0.65623719, 0.99422221, 1.37241668,
150+
-0.05163759, 0.82729935, 0.59458429, 1.10870872, -1.00730291, -0.07837131,
151+
-0.28144688, -0.03052015, 1.05263496, 0.19011829, -0.98807301, -0.77388355,
152+
-1.68729554, 0.03018351, 0.39424573, 0.98343413, -1.40600196, 1.19764243,
153+
1.64712279, 0.68929684, -0.54301669, -0.29369924, 0.09052877, 2.64067523,
154+
-1.25887138, 1.65991714, 0.71271397, -0.50396329, 1.2182173, 0.2472108,
155+
-0.2990774, 0.1646579, 0.21418971, -0.0876372, 0.66714317, -0.43490764,
156+
-2.17899663, -0.2681325, -3.10431098, -1.38211864, 0.02041712, 0.16319981,
157+
-1.02526047, 1.93088335, -0.36975507, -0.61332039, 0.33666881, -0.23766903,
158+
-0.58478679, 1.38941035, -0.45829187, -1.12505096, -1.4814355, 0.61790977,
159+
0.58867984, 1.38693864, 1.80845772, -1.63246225, -1.48247172, -0.69197631,
160+
0.65045375, -0.09601979]),
167161
}
168162

169163
def setup_class(self):

0 commit comments

Comments
 (0)