Skip to content

Commit 925c3c9

Browse files
Remove discard_tuned_samples from _mp_sample
1 parent 4c020e7 commit 925c3c9

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

pymc/sampling/mcmc.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,12 @@ def sample(
525525
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
526526
_print_step_hierarchy(step)
527527
try:
528-
mtrace = _mp_sample(**sample_args, **parallel_args)
528+
traces = _mp_sample(**sample_args, **parallel_args)
529+
if discard_tuned_samples:
530+
traces, length = _choose_chains(traces, tune)
531+
else:
532+
traces, length = _choose_chains(traces, 0)
533+
mtrace = MultiTrace(traces)[:length]
529534
except pickle.PickleError:
530535
_log.warning("Could not pickle model, sampling singlethreaded.")
531536
_log.debug("Pickling error:", exc_info=True)
@@ -942,10 +947,9 @@ def _mp_sample(
942947
trace: Optional[BaseTrace] = None,
943948
model=None,
944949
callback=None,
945-
discard_tuned_samples: bool = True,
946950
mp_ctx=None,
947951
**kwargs,
948-
) -> MultiTrace:
952+
) -> List[BaseTrace]:
949953
"""Main iteration for multiprocess sampling.
950954
951955
Parameters
@@ -980,8 +984,8 @@ def _mp_sample(
980984
981985
Returns
982986
-------
983-
mtrace : pymc.backends.base.MultiTrace
984-
A ``MultiTrace`` object that contains the samples for all chains.
987+
traces
988+
All chains.
985989
"""
986990
import pymc.sampling.parallel as ps
987991

@@ -1031,13 +1035,9 @@ def _mp_sample(
10311035
multitrace = MultiTrace(traces)
10321036
multitrace._report._log_summary()
10331037
raise
1034-
return MultiTrace(traces)
1038+
return traces
10351039
except KeyboardInterrupt:
1036-
if discard_tuned_samples:
1037-
traces, length = _choose_chains(traces, tune)
1038-
else:
1039-
traces, length = _choose_chains(traces, 0)
1040-
return MultiTrace(traces)[:length]
1040+
return traces
10411041
finally:
10421042
for strace in traces:
10431043
strace.close()

0 commit comments

Comments
 (0)