@@ -133,7 +133,7 @@ def check_stat_dtype(self, step, idata):
133
133
continue
134
134
assert idata .sample_stats [stat ].dtype == np .dtype (dtype )
135
135
136
- def step_continuous (self , step_fn , draws ):
136
+ def step_continuous (self , step_fn , draws , chains = 1 , tune = 1000 ):
137
137
start , model , (mu , C ) = mv_simple ()
138
138
unc = np .diag (C ) ** 0.5
139
139
check = (("x" , np .mean , mu , unc / 10 ), ("x" , np .std , unc , unc / 10 ))
@@ -143,14 +143,19 @@ def step_continuous(self, step_fn, draws):
143
143
with warnings .catch_warnings ():
144
144
warnings .filterwarnings ("ignore" , "More chains .* than draws .*" , UserWarning )
145
145
idata = pm .sample (
146
- tune = 1000 ,
146
+ tune = tune ,
147
147
draws = draws ,
148
- chains = 1 ,
148
+ chains = chains ,
149
149
step = step ,
150
150
initvals = start ,
151
151
model = model ,
152
152
random_seed = 1 ,
153
+ discard_tuned_samples = False ,
153
154
)
155
+ assert idata .warmup_posterior .sizes ["chain" ] == chains
156
+ assert idata .warmup_posterior .sizes ["draw" ] == tune
157
+ assert idata .posterior .sizes ["chain" ] == chains
158
+ assert idata .posterior .sizes ["draw" ] == draws
154
159
self .check_stat (check , idata , step .__class__ .__name__ )
155
160
self .check_stat_dtype (idata , step )
156
161
0 commit comments