Skip to content

Commit fb58bed

Browse files
aloctavodiatwiecki
authored andcommitted
SMC, use all the samples from the last stage (#2543)
* SMC, use all the samples from the last stage * revert removed idxs
1 parent 6966e8d commit fb58bed

File tree

4 files changed

+73
-48
lines changed

4 files changed

+73
-48
lines changed

docs/source/notebooks/SMC2_gaussians.ipynb

Lines changed: 23 additions & 22 deletions
Large diffs are not rendered by default.

pymc3/backends/smc_text.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def create_result_trace(self, stage_number=-1, idxs=(-1, ), step=None, model=Non
305305
stage of which result traces are loaded
306306
idxs : iterable
307307
of indexes to the point at each chain to extract and concatenate
308-
309308
Returns
310309
-------
311310
MultiTrace
@@ -478,6 +477,6 @@ def point(self, idx):
478477
self._load_df()
479478
pt = {}
480479
for varname in self.varnames:
481-
vals = self.df[self.flat_names[varname]].iloc[idx]
480+
vals = self.df[self.flat_names[varname]].iloc[idx].values
482481
pt[varname] = vals.reshape(self.var_shapes[varname])
483482
return pt

pymc3/step_methods/smc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,10 @@ def sample_smc(n_steps, n_chains=100, step=None, start=None, homepath=None, stag
590590
_iter_parallel_chains(**sample_args)
591591

592592
stage_handler.dump_atmip_params(step)
593-
return stage_handler.create_result_trace(step.stage, step=step, model=model)
593+
return stage_handler.create_result_trace(step.stage,
594+
idxs=range(n_steps),
595+
step=step,
596+
model=model)
594597

595598

596599
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,

pymc3/tests/test_step.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -124,27 +124,47 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
124124
-3.91174647e-01, -2.60664979e+00, -2.27637534e+00, -2.81505065e+00,
125125
-2.24238542e+00, -1.01648100e+00, -1.01648100e+00, -7.60912865e-01,
126126
1.44384812e+00, 2.07355127e+00, 1.91390340e+00, 1.66559696e+00]),
127-
smc.SMC: np.array([
128-
-0.26421709, -2.07555186, 1.03443124, 0.16260898, -0.2809841 ,
129-
-0.35185097, -0.56387677, 0.18332851, 1.59614152, 0.39866217,
130-
-0.55781016, -0.74446992, 0.41198452, 0.47484429, 0.43417346,
131-
1.24153494, 1.10037457, 2.55408602, -1.47011338, 0.50824935,
132-
-2.09842977, 0.74269458, 0.31025837, 0.48376623, 1.74272003,
133-
-0.3975872 , -0.83735649, -0.33724478, 1.20300335, 1.40710795,
134-
-0.63740634, -0.33976389, -0.95412333, 1.84658352, 1.2000763 ,
135-
-1.08264783, -1.55367546, 0.66209331, 0.6577848 , 0.5727828 ,
136-
0.30248057, 0.89674302, 0.70148518, 0.56483303, 1.35161821,
137-
0.06392528, 0.70670242, 1.04846633, 0.54696351, -2.49061003,
138-
-1.29925327, -1.31906407, -0.36650058, -1.44809118, -0.96224606,
139-
-0.2501728 , -1.88779999, 0.35774637, 1.06917986, 2.07049617,
140-
-0.18667668, 0.19360673, -0.37665179, 0.98526962, 1.03010772,
141-
-0.25348684, 2.43418902, 0.89153789, -1.02035572, 1.77851957,
142-
0.6408621 , 0.50163095, 0.59934511, 0.73985647, 0.78719236,
143-
-0.41001864, -1.99859554, 1.53574307, -1.71336207, 1.04355849,
144-
0.21864817, -2.03911519, -0.42358936, -0.49666918, 1.64327219,
145-
-0.86416032, 1.10236002, 0.16396354, -0.13313781, 0.32649281,
146-
-1.01918397, 0.20525201, 1.04927506, 0.98243013, 2.46970704,
147-
-0.68709777, 2.05038381, 0.71417231, 1.13267395, -0.48644823]),
127+
smc.SMC: np.array(
128+
[ 1.30059573, -1. , 1.30059573, -1.28860918, 1.30059573,
129+
-1.33854363, 0.98809372, -0.22700433, 0.98809372, 0.45421367,
130+
0.53534095, -0.0571964 , 0.53534095, -0.58075355, 0.53534095,
131+
-0.81941713, 0.21067768, -0.77333386, 0.21067768, -0.57169475,
132+
0.21067768, -0.57169475, 0.71695573, -0.1735022 , 0.87048219,
133+
-0.28469019, 1.08731483, 0.08746968, 1.08059419, 0.08746968,
134+
0.31491769, -0.17753158, 0.48834878, 0.99152949, -0.1423678 ,
135+
0.1923664 , 0.06791856, -0.99708314, -0.13981681, -0.99708314,
136+
0.15039906, -0.99708314, 0.15039906, -0.6557885 , 0.15039906,
137+
-0.6557885 , 0.3553436 , -0.22781864, 0.3553436 , -0.86087058,
138+
0.3553436 , -1.26758014, 0.3553436 , -0.02546953, 1.992939 ,
139+
0.03739508, 1.992939 , 0.04077929, 1.47964467, -0.79954537,
140+
1.36470456, -1.28038148, 1.34975939, -1.28038148, 0.5058148 ,
141+
-1.28038148, 0.46681777, -1.51635697, 1.14761057, -1.51635697,
142+
0.70585017, -1.51635697, 0.56298035, -1.51635697, 0.68107999,
143+
-1.24900543, 0.68107999, -1.58687463, 0.8251361 , -0.30236423,
144+
0.19971902, 0.0871776 , 0.19971902, 0.51328569, 0.19971902,
145+
0.55526923, 0.19971902, 0.16065882, -0.87573391, 0.42539449,
146+
-0.87573391, 0.31060689, -0.87573391, 0.48370178, 0.27495794,
147+
0.48370178, 0.37129344, 0.48370178, 1.06413954, 0.48370178,
148+
1.57177313, 0.01683961, 1.75583481, 0.01683961, 1.87895941,
149+
0.49419352, 1.87895941, 0.49419352, 1.58832631, -0.02168877,
150+
1.58832631, 0.79617759, 1.41454982, 0.79617759, 1.93168471,
151+
0.78016131, 1.73345978, 0.19202933, 1.62254723, -0.22699057,
152+
1.62254723, -0.37699978, -0.14380698, -1.39915323, -0.0647066 ,
153+
-1.39915323, -0.27796904, -0.76046542, -1.39097353, -0.95882837,
154+
-2.37809137, -0.95882837, -0.76288136, -0.95882837, 0.10702519,
155+
0.6982565 , 0.47017639, 0.6982565 , 0.9479599 , 1.22950397,
156+
0.9479599 , 0.06416429, 0.9479599 , -0.0761023 , 1.18090459,
157+
-0.54169371, 1.18090459, 0.00447742, 0.93159018, 0.77757319,
158+
0.76557639, 0.77757319, 0.41731133, 0.77757319, 0.68380287,
159+
1.08320749, -0.43989818, 1.05260977, -0.32229488, 1.15892126,
160+
0.38764854, 1.15892126, 0.41491972, 0.71650402, -0.27923606,
161+
0.71650402, -1.32493526, -0.09371072, -0.74606271, 0.5227973 ,
162+
-0.74606271, 0.63974633, -1.41947892, -0.96488174, -1.294319 ,
163+
-0.96488174, -1.294319 , -0.78112189, -1.294319 , 0.45821163,
164+
-1.31499922, 0.54901984, -1.10260234, 1.12896946, 0.43768361,
165+
1.12896946, -0.58455279, 1.12896946, -0.58455279, 1.12896946,
166+
-0.58455279, 0.19820143, -1.19295628, -0.02548627, -1.19295628,
167+
0.50411866, -1.19295628, 0.50411866, -1.55631463, 0.92268245]),
148168
}
149169

150170
def setup_class(self):
@@ -180,9 +200,12 @@ def check_trace(self, step_method):
180200
x = Normal('x', mu=0, sd=1)
181201
if step_method.__name__ == 'SMC':
182202
trace = smc.sample_smc(n_steps=n_steps,
183-
step=step_method(random_seed=1),
203+
n_chains=2,
204+
start=[{'x':1.}, {'x':-1.}],
205+
random_seed=1,
184206
n_jobs=1, progressbar=False,
185207
homepath=self.temp_dir)
208+
186209
elif step_method.__name__ == 'NUTS':
187210
step = step_method(scaling=model.test_point)
188211
trace = sample(0, tune=n_steps,
@@ -192,7 +215,6 @@ def check_trace(self, step_method):
192215
trace = sample(0, tune=n_steps,
193216
discard_tuned_samples=False,
194217
step=step_method(), random_seed=1)
195-
196218
assert_array_almost_equal(
197219
trace.get_values('x'),
198220
self.master_samples[step_method],

0 commit comments

Comments
 (0)