Skip to content

Commit c26a55e

Browse files
Truncate retrieved arrays to chain length
Closes #38
1 parent f37cd93 commit c26a55e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
pass
2121

2222

23-
__version__ = "0.1.1"
23+
__version__ = "0.1.2"

mcbackend/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,13 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
188188
posterior = collections.defaultdict(list)
189189
sample_stats = collections.defaultdict(list)
190190
for c, chain in enumerate(chains):
191+
# Every retrieved array is shortened to the previously determined chain length.
192+
# This is needed for database backends which may get inserts inbetween.
193+
clen = chain_lengths[chain.cid]
194+
191195
# Obtain a mask by which draws can be split into warmup/posterior
192196
if "tune" in chain.sample_stats:
193-
tune = chain.get_stats("tune").astype(bool)
197+
tune = chain.get_stats("tune")[:clen].astype(bool)
194198
else:
195199
if c == 0:
196200
_log.warning(
@@ -200,12 +204,12 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
200204

201205
# Split all variables draws into warmup/posterior
202206
for var in variables:
203-
draws = chain.get_draws(var.name)
207+
draws = chain.get_draws(var.name)[:clen]
204208
warmup_posterior[var.name].append(draws[tune])
205209
posterior[var.name].append(draws[~tune])
206210
# Same for sample stats
207211
for svar in self.meta.sample_stats:
208-
stats = chain.get_stats(svar.name)
212+
stats = chain.get_stats(svar.name)[:clen]
209213
warmup_sample_stats[svar.name].append(stats[tune])
210214
sample_stats[svar.name].append(stats[~tune])
211215

0 commit comments

Comments
 (0)