@@ -525,7 +525,12 @@ def sample(
525
525
_log .info (f"Multiprocess sampling ({ chains } chains in { cores } jobs)" )
526
526
_print_step_hierarchy (step )
527
527
try :
528
- mtrace = _mp_sample (** sample_args , ** parallel_args )
528
+ traces = _mp_sample (** sample_args , ** parallel_args )
529
+ if discard_tuned_samples :
530
+ traces , length = _choose_chains (traces , tune )
531
+ else :
532
+ traces , length = _choose_chains (traces , 0 )
533
+ mtrace = MultiTrace (traces )[:length ]
529
534
except pickle .PickleError :
530
535
_log .warning ("Could not pickle model, sampling singlethreaded." )
531
536
_log .debug ("Pickling error:" , exc_info = True )
@@ -942,10 +947,9 @@ def _mp_sample(
942
947
trace : Optional [BaseTrace ] = None ,
943
948
model = None ,
944
949
callback = None ,
945
- discard_tuned_samples : bool = True ,
946
950
mp_ctx = None ,
947
951
** kwargs ,
948
- ) -> MultiTrace :
952
+ ) -> List [ BaseTrace ] :
949
953
"""Main iteration for multiprocess sampling.
950
954
951
955
Parameters
@@ -980,8 +984,8 @@ def _mp_sample(
980
984
981
985
Returns
982
986
-------
983
- mtrace : pymc.backends.base.MultiTrace
984
- A ``MultiTrace`` object that contains the samples for all chains.
987
+ traces
988
+ All chains.
985
989
"""
986
990
import pymc .sampling .parallel as ps
987
991
@@ -1031,13 +1035,9 @@ def _mp_sample(
1031
1035
multitrace = MultiTrace (traces )
1032
1036
multitrace ._report ._log_summary ()
1033
1037
raise
1034
- return MultiTrace ( traces )
1038
+ return traces
1035
1039
except KeyboardInterrupt :
1036
- if discard_tuned_samples :
1037
- traces , length = _choose_chains (traces , tune )
1038
- else :
1039
- traces , length = _choose_chains (traces , 0 )
1040
- return MultiTrace (traces )[:length ]
1040
+ return traces
1041
1041
finally :
1042
1042
for strace in traces :
1043
1043
strace .close ()
0 commit comments