@@ -124,27 +124,47 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
124
124
- 3.91174647e-01 , - 2.60664979e+00 , - 2.27637534e+00 , - 2.81505065e+00 ,
125
125
- 2.24238542e+00 , - 1.01648100e+00 , - 1.01648100e+00 , - 7.60912865e-01 ,
126
126
1.44384812e+00 , 2.07355127e+00 , 1.91390340e+00 , 1.66559696e+00 ]),
127
- smc .SMC : np .array ([
128
- - 0.26421709 , - 2.07555186 , 1.03443124 , 0.16260898 , - 0.2809841 ,
129
- - 0.35185097 , - 0.56387677 , 0.18332851 , 1.59614152 , 0.39866217 ,
130
- - 0.55781016 , - 0.74446992 , 0.41198452 , 0.47484429 , 0.43417346 ,
131
- 1.24153494 , 1.10037457 , 2.55408602 , - 1.47011338 , 0.50824935 ,
132
- - 2.09842977 , 0.74269458 , 0.31025837 , 0.48376623 , 1.74272003 ,
133
- - 0.3975872 , - 0.83735649 , - 0.33724478 , 1.20300335 , 1.40710795 ,
134
- - 0.63740634 , - 0.33976389 , - 0.95412333 , 1.84658352 , 1.2000763 ,
135
- - 1.08264783 , - 1.55367546 , 0.66209331 , 0.6577848 , 0.5727828 ,
136
- 0.30248057 , 0.89674302 , 0.70148518 , 0.56483303 , 1.35161821 ,
137
- 0.06392528 , 0.70670242 , 1.04846633 , 0.54696351 , - 2.49061003 ,
138
- - 1.29925327 , - 1.31906407 , - 0.36650058 , - 1.44809118 , - 0.96224606 ,
139
- - 0.2501728 , - 1.88779999 , 0.35774637 , 1.06917986 , 2.07049617 ,
140
- - 0.18667668 , 0.19360673 , - 0.37665179 , 0.98526962 , 1.03010772 ,
141
- - 0.25348684 , 2.43418902 , 0.89153789 , - 1.02035572 , 1.77851957 ,
142
- 0.6408621 , 0.50163095 , 0.59934511 , 0.73985647 , 0.78719236 ,
143
- - 0.41001864 , - 1.99859554 , 1.53574307 , - 1.71336207 , 1.04355849 ,
144
- 0.21864817 , - 2.03911519 , - 0.42358936 , - 0.49666918 , 1.64327219 ,
145
- - 0.86416032 , 1.10236002 , 0.16396354 , - 0.13313781 , 0.32649281 ,
146
- - 1.01918397 , 0.20525201 , 1.04927506 , 0.98243013 , 2.46970704 ,
147
- - 0.68709777 , 2.05038381 , 0.71417231 , 1.13267395 , - 0.48644823 ]),
127
+ smc .SMC : np .array (
128
+ [ 1.30059573 , - 1. , 1.30059573 , - 1.28860918 , 1.30059573 ,
129
+ - 1.33854363 , 0.98809372 , - 0.22700433 , 0.98809372 , 0.45421367 ,
130
+ 0.53534095 , - 0.0571964 , 0.53534095 , - 0.58075355 , 0.53534095 ,
131
+ - 0.81941713 , 0.21067768 , - 0.77333386 , 0.21067768 , - 0.57169475 ,
132
+ 0.21067768 , - 0.57169475 , 0.71695573 , - 0.1735022 , 0.87048219 ,
133
+ - 0.28469019 , 1.08731483 , 0.08746968 , 1.08059419 , 0.08746968 ,
134
+ 0.31491769 , - 0.17753158 , 0.48834878 , 0.99152949 , - 0.1423678 ,
135
+ 0.1923664 , 0.06791856 , - 0.99708314 , - 0.13981681 , - 0.99708314 ,
136
+ 0.15039906 , - 0.99708314 , 0.15039906 , - 0.6557885 , 0.15039906 ,
137
+ - 0.6557885 , 0.3553436 , - 0.22781864 , 0.3553436 , - 0.86087058 ,
138
+ 0.3553436 , - 1.26758014 , 0.3553436 , - 0.02546953 , 1.992939 ,
139
+ 0.03739508 , 1.992939 , 0.04077929 , 1.47964467 , - 0.79954537 ,
140
+ 1.36470456 , - 1.28038148 , 1.34975939 , - 1.28038148 , 0.5058148 ,
141
+ - 1.28038148 , 0.46681777 , - 1.51635697 , 1.14761057 , - 1.51635697 ,
142
+ 0.70585017 , - 1.51635697 , 0.56298035 , - 1.51635697 , 0.68107999 ,
143
+ - 1.24900543 , 0.68107999 , - 1.58687463 , 0.8251361 , - 0.30236423 ,
144
+ 0.19971902 , 0.0871776 , 0.19971902 , 0.51328569 , 0.19971902 ,
145
+ 0.55526923 , 0.19971902 , 0.16065882 , - 0.87573391 , 0.42539449 ,
146
+ - 0.87573391 , 0.31060689 , - 0.87573391 , 0.48370178 , 0.27495794 ,
147
+ 0.48370178 , 0.37129344 , 0.48370178 , 1.06413954 , 0.48370178 ,
148
+ 1.57177313 , 0.01683961 , 1.75583481 , 0.01683961 , 1.87895941 ,
149
+ 0.49419352 , 1.87895941 , 0.49419352 , 1.58832631 , - 0.02168877 ,
150
+ 1.58832631 , 0.79617759 , 1.41454982 , 0.79617759 , 1.93168471 ,
151
+ 0.78016131 , 1.73345978 , 0.19202933 , 1.62254723 , - 0.22699057 ,
152
+ 1.62254723 , - 0.37699978 , - 0.14380698 , - 1.39915323 , - 0.0647066 ,
153
+ - 1.39915323 , - 0.27796904 , - 0.76046542 , - 1.39097353 , - 0.95882837 ,
154
+ - 2.37809137 , - 0.95882837 , - 0.76288136 , - 0.95882837 , 0.10702519 ,
155
+ 0.6982565 , 0.47017639 , 0.6982565 , 0.9479599 , 1.22950397 ,
156
+ 0.9479599 , 0.06416429 , 0.9479599 , - 0.0761023 , 1.18090459 ,
157
+ - 0.54169371 , 1.18090459 , 0.00447742 , 0.93159018 , 0.77757319 ,
158
+ 0.76557639 , 0.77757319 , 0.41731133 , 0.77757319 , 0.68380287 ,
159
+ 1.08320749 , - 0.43989818 , 1.05260977 , - 0.32229488 , 1.15892126 ,
160
+ 0.38764854 , 1.15892126 , 0.41491972 , 0.71650402 , - 0.27923606 ,
161
+ 0.71650402 , - 1.32493526 , - 0.09371072 , - 0.74606271 , 0.5227973 ,
162
+ - 0.74606271 , 0.63974633 , - 1.41947892 , - 0.96488174 , - 1.294319 ,
163
+ - 0.96488174 , - 1.294319 , - 0.78112189 , - 1.294319 , 0.45821163 ,
164
+ - 1.31499922 , 0.54901984 , - 1.10260234 , 1.12896946 , 0.43768361 ,
165
+ 1.12896946 , - 0.58455279 , 1.12896946 , - 0.58455279 , 1.12896946 ,
166
+ - 0.58455279 , 0.19820143 , - 1.19295628 , - 0.02548627 , - 1.19295628 ,
167
+ 0.50411866 , - 1.19295628 , 0.50411866 , - 1.55631463 , 0.92268245 ]),
148
168
}
149
169
150
170
def setup_class (self ):
@@ -180,9 +200,12 @@ def check_trace(self, step_method):
180
200
x = Normal ('x' , mu = 0 , sd = 1 )
181
201
if step_method .__name__ == 'SMC' :
182
202
trace = smc .sample_smc (n_steps = n_steps ,
183
- step = step_method (random_seed = 1 ),
203
+ n_chains = 2 ,
204
+ start = [{'x' :1. }, {'x' :- 1. }],
205
+ random_seed = 1 ,
184
206
n_jobs = 1 , progressbar = False ,
185
207
homepath = self .temp_dir )
208
+
186
209
elif step_method .__name__ == 'NUTS' :
187
210
step = step_method (scaling = model .test_point )
188
211
trace = sample (0 , tune = n_steps ,
@@ -192,7 +215,6 @@ def check_trace(self, step_method):
192
215
trace = sample (0 , tune = n_steps ,
193
216
discard_tuned_samples = False ,
194
217
step = step_method (), random_seed = 1 )
195
-
196
218
assert_array_almost_equal (
197
219
trace .get_values ('x' ),
198
220
self .master_samples [step_method ],
0 commit comments