@@ -479,11 +479,23 @@ def sample(
479
479
model .check_start_vals (ip )
480
480
_check_start_shape (model , ip )
481
481
482
+ # Create trace backends for each chain
483
+ traces = [
484
+ _init_trace (
485
+ expected_length = draws + tune ,
486
+ stats_dtypes = step .stats_dtypes ,
487
+ chain_number = chain_number ,
488
+ trace = trace ,
489
+ model = model ,
490
+ )
491
+ for chain_number in range (chains )
492
+ ]
493
+
482
494
sample_args = {
483
495
"draws" : draws ,
484
496
"step" : step ,
485
497
"start" : initial_points ,
486
- "trace " : trace ,
498
+ "traces " : traces ,
487
499
"chains" : chains ,
488
500
"tune" : tune ,
489
501
"progressbar" : progressbar ,
@@ -524,12 +536,7 @@ def sample(
524
536
_log .info (f"Multiprocess sampling ({ chains } chains in { cores } jobs)" )
525
537
_print_step_hierarchy (step )
526
538
try :
527
- traces = _mp_sample (** sample_args , ** parallel_args )
528
- if discard_tuned_samples :
529
- traces , length = _choose_chains (traces , tune )
530
- else :
531
- traces , length = _choose_chains (traces , 0 )
532
- mtrace = MultiTrace (traces )[:length ]
539
+ _mp_sample (** sample_args , ** parallel_args )
533
540
except pickle .PickleError :
534
541
_log .warning ("Could not pickle model, sampling singlethreaded." )
535
542
_log .debug ("Pickling error:" , exc_info = True )
@@ -544,15 +551,21 @@ def sample(
544
551
if has_population_samplers :
545
552
_log .info (f"Population sampling ({ chains } chains)" )
546
553
_print_step_hierarchy (step )
547
- mtrace = _sample_population (
548
- initial_points = initial_points , parallelize = cores > 1 , ** sample_args
549
- )
554
+ _sample_population (initial_points = initial_points , parallelize = cores > 1 , ** sample_args )
550
555
else :
551
556
_log .info (f"Sequential sampling ({ chains } chains in 1 job)" )
552
557
_print_step_hierarchy (step )
553
- mtrace = _sample_many (** sample_args )
558
+ _sample_many (** sample_args )
554
559
555
560
t_sampling = time .time () - t_start
561
+
562
+ # Wrap chain traces in a MultiTrace
563
+ if discard_tuned_samples :
564
+ traces , length = _choose_chains (traces , tune )
565
+ else :
566
+ traces , length = _choose_chains (traces , 0 )
567
+ mtrace = MultiTrace (traces )[:length ]
568
+
556
569
# count the number of tune/draw iterations that happened
557
570
# ideally via the "tune" statistic, but not all samplers record it!
558
571
if "tune" in mtrace .stat_names :
@@ -639,12 +652,13 @@ def _sample_many(
639
652
* ,
640
653
draws : int ,
641
654
chains : int ,
655
+ traces : Sequence [BaseTrace ],
642
656
start : Sequence [PointType ],
643
657
random_seed : Optional [Sequence [RandomSeed ]],
644
658
step ,
645
659
callback = None ,
646
660
** kwargs ,
647
- ) -> MultiTrace :
661
+ ):
648
662
"""Samples all chains sequentially.
649
663
650
664
Parameters
@@ -659,35 +673,19 @@ def _sample_many(
659
673
A list of seeds, one for each chain
660
674
step: function
661
675
Step function
662
-
663
- Returns
664
- -------
665
- mtrace: MultiTrace
666
- Contains samples of all chains
667
676
"""
668
- traces : List [BaseTrace ] = []
669
677
for i in range (chains ):
670
- trace = _sample (
678
+ _sample (
671
679
draws = draws ,
672
680
chain = i ,
673
681
start = start [i ],
674
682
step = step ,
683
+ trace = traces [i ],
675
684
random_seed = None if random_seed is None else random_seed [i ],
676
685
callback = callback ,
677
686
** kwargs ,
678
687
)
679
- if trace is None :
680
- if len (traces ) == 0 :
681
- raise ValueError ("Sampling stopped before a sample was created." )
682
- else :
683
- break
684
- elif len (trace ) < draws :
685
- if len (traces ) == 0 :
686
- traces .append (trace )
687
- break
688
- else :
689
- traces .append (trace )
690
- return MultiTrace (traces )
688
+ return
691
689
692
690
693
691
def _sample (
@@ -698,12 +696,12 @@ def _sample(
698
696
start : PointType ,
699
697
draws : int ,
700
698
step = None ,
701
- trace : Optional [ BaseTrace ] = None ,
699
+ trace : BaseTrace ,
702
700
tune : int ,
703
701
model : Optional [Model ] = None ,
704
702
callback = None ,
705
703
** kwargs ,
706
- ) -> BaseTrace :
704
+ ) -> None :
707
705
"""Main iteration for singleprocess sampling.
708
706
709
707
Multiple step methods are supported via compound step methods.
@@ -724,16 +722,10 @@ def _sample(
724
722
step : function
725
723
Step function
726
724
trace : backend, optional
727
- A backend instance or None.
728
- If None, the NDArray backend is used.
725
+ A backend instance.
729
726
tune : int
730
727
Number of iterations to tune.
731
728
model : Model (optional if in ``with`` context)
732
-
733
- Returns
734
- -------
735
- strace : BaseTrace
736
- A ``BaseTrace`` object that contains the samples for this chain.
737
729
"""
738
730
skip_first = kwargs .get ("skip_first" , 0 )
739
731
@@ -756,31 +748,27 @@ def _sample(
756
748
else :
757
749
sampling = sampling_gen
758
750
try :
759
- strace = None
760
- for it , (strace , diverging ) in enumerate (sampling ):
751
+ for it , diverging in enumerate (sampling ):
761
752
if it >= skip_first and diverging :
762
753
_pbar_data ["divergences" ] += 1
763
754
if progressbar :
764
755
sampling .comment = _desc .format (** _pbar_data )
765
756
except KeyboardInterrupt :
766
757
pass
767
- if strace is None :
768
- raise Exception ("KeyboardInterrupt happened before the base trace was created." )
769
- return strace
770
758
771
759
772
760
def _iter_sample (
773
761
* ,
774
762
draws : int ,
775
763
step ,
776
764
start : PointType ,
777
- trace : Optional [ BaseTrace ] = None ,
765
+ trace : BaseTrace ,
778
766
chain : int = 0 ,
779
767
tune : int = 0 ,
780
768
model = None ,
781
769
random_seed : RandomSeed = None ,
782
770
callback = None ,
783
- ) -> Iterator [Tuple [ BaseTrace , bool ] ]:
771
+ ) -> Iterator [bool ]:
784
772
"""Generator for sampling one chain. (Used in singleprocess sampling.)
785
773
786
774
Parameters
@@ -792,9 +780,8 @@ def _iter_sample(
792
780
start : dict
793
781
Starting point in parameter space (or partial point).
794
782
Must contain numeric (transformed) initial values for all (transformed) free variables.
795
- trace : backend, optional
796
- A backend instance or None.
797
- If None, the NDArray backend is used.
783
+ trace : backend
784
+ A backend instance.
798
785
chain : int, optional
799
786
Chain number used to store sample in backend.
800
787
tune : int, optional
@@ -804,8 +791,6 @@ def _iter_sample(
804
791
805
792
Yields
806
793
------
807
- strace : BaseTrace
808
- The trace object containing the samples for this chain
809
794
diverging : bool
810
795
Indicates if the draw is divergent. Only available with some samplers.
811
796
"""
@@ -825,14 +810,6 @@ def _iter_sample(
825
810
826
811
point = start
827
812
828
- strace : BaseTrace = _init_trace (
829
- expected_length = draws + tune ,
830
- stats_dtypes = step .stats_dtypes ,
831
- chain_number = chain ,
832
- trace = trace ,
833
- model = model ,
834
- )
835
-
836
813
try :
837
814
step .tune = bool (tune )
838
815
if hasattr (step , "reset_tuning" ):
@@ -846,24 +823,24 @@ def _iter_sample(
846
823
if i == tune :
847
824
step .stop_tuning ()
848
825
point , stats = step .step (point )
849
- strace .record (point , stats )
826
+ trace .record (point , stats )
850
827
log_warning_stats (stats )
851
828
diverging = i > tune and stats and stats [0 ].get ("diverging" )
852
829
if callback is not None :
853
830
callback (
854
- trace = strace ,
831
+ trace = trace ,
855
832
draw = Draw (chain , i == draws , i , i < tune , stats , point ),
856
833
)
857
834
858
- yield strace , diverging
835
+ yield diverging
859
836
except KeyboardInterrupt :
860
- strace .close ()
837
+ trace .close ()
861
838
raise
862
839
except BaseException :
863
- strace .close ()
840
+ trace .close ()
864
841
raise
865
842
else :
866
- strace .close ()
843
+ trace .close ()
867
844
868
845
869
846
def _mp_sample (
@@ -876,12 +853,12 @@ def _mp_sample(
876
853
random_seed : Sequence [RandomSeed ],
877
854
start : Sequence [PointType ],
878
855
progressbar : bool = True ,
879
- trace : Optional [BaseTrace ] = None ,
856
+ traces : Sequence [BaseTrace ],
880
857
model = None ,
881
858
callback = None ,
882
859
mp_ctx = None ,
883
860
** kwargs ,
884
- ) -> List [ BaseTrace ] :
861
+ ) -> None :
885
862
"""Main iteration for multiprocess sampling.
886
863
887
864
Parameters
@@ -913,28 +890,12 @@ def _mp_sample(
913
890
the ``draw.chain`` argument can be used to determine which of the active chains the sample
914
891
is drawn from.
915
892
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
916
-
917
- Returns
918
- -------
919
- traces
920
- All chains.
921
893
"""
922
894
import pymc .sampling .parallel as ps
923
895
924
896
# We did draws += tune in pm.sample
925
897
draws -= tune
926
898
927
- traces = [
928
- _init_trace (
929
- expected_length = draws + tune ,
930
- stats_dtypes = step .stats_dtypes ,
931
- chain_number = chain_number ,
932
- trace = trace ,
933
- model = model ,
934
- )
935
- for chain_number in range (chains )
936
- ]
937
-
938
899
sampler = ps .ParallelSampler (
939
900
draws = draws ,
940
901
tune = tune ,
@@ -957,7 +918,7 @@ def _mp_sample(
957
918
strace .close ()
958
919
959
920
if callback is not None :
960
- callback (trace = trace , draw = draw )
921
+ callback (trace = strace , draw = draw )
961
922
962
923
except ps .ParallelSamplingError as error :
963
924
strace = traces [error ._chain ]
@@ -967,9 +928,8 @@ def _mp_sample(
967
928
multitrace = MultiTrace (traces )
968
929
multitrace ._report ._log_summary ()
969
930
raise
970
- return traces
971
931
except KeyboardInterrupt :
972
- return traces
932
+ pass
973
933
finally :
974
934
for strace in traces :
975
935
strace .close ()
0 commit comments