Skip to content

Commit d6f1aad

Browse files
Refactor sampling functions to use initialized traces
1 parent c0e017d commit d6f1aad

File tree

2 files changed

+64
-124
lines changed

2 files changed

+64
-124
lines changed

pymc/sampling/mcmc.py

Lines changed: 47 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -479,11 +479,23 @@ def sample(
479479
model.check_start_vals(ip)
480480
_check_start_shape(model, ip)
481481

482+
# Create trace backends for each chain
483+
traces = [
484+
_init_trace(
485+
expected_length=draws + tune,
486+
stats_dtypes=step.stats_dtypes,
487+
chain_number=chain_number,
488+
trace=trace,
489+
model=model,
490+
)
491+
for chain_number in range(chains)
492+
]
493+
482494
sample_args = {
483495
"draws": draws,
484496
"step": step,
485497
"start": initial_points,
486-
"trace": trace,
498+
"traces": traces,
487499
"chains": chains,
488500
"tune": tune,
489501
"progressbar": progressbar,
@@ -524,12 +536,7 @@ def sample(
524536
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
525537
_print_step_hierarchy(step)
526538
try:
527-
traces = _mp_sample(**sample_args, **parallel_args)
528-
if discard_tuned_samples:
529-
traces, length = _choose_chains(traces, tune)
530-
else:
531-
traces, length = _choose_chains(traces, 0)
532-
mtrace = MultiTrace(traces)[:length]
539+
_mp_sample(**sample_args, **parallel_args)
533540
except pickle.PickleError:
534541
_log.warning("Could not pickle model, sampling singlethreaded.")
535542
_log.debug("Pickling error:", exc_info=True)
@@ -544,15 +551,21 @@ def sample(
544551
if has_population_samplers:
545552
_log.info(f"Population sampling ({chains} chains)")
546553
_print_step_hierarchy(step)
547-
mtrace = _sample_population(
548-
initial_points=initial_points, parallelize=cores > 1, **sample_args
549-
)
554+
_sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args)
550555
else:
551556
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
552557
_print_step_hierarchy(step)
553-
mtrace = _sample_many(**sample_args)
558+
_sample_many(**sample_args)
554559

555560
t_sampling = time.time() - t_start
561+
562+
# Wrap chain traces in a MultiTrace
563+
if discard_tuned_samples:
564+
traces, length = _choose_chains(traces, tune)
565+
else:
566+
traces, length = _choose_chains(traces, 0)
567+
mtrace = MultiTrace(traces)[:length]
568+
556569
# count the number of tune/draw iterations that happened
557570
# ideally via the "tune" statistic, but not all samplers record it!
558571
if "tune" in mtrace.stat_names:
@@ -639,12 +652,13 @@ def _sample_many(
639652
*,
640653
draws: int,
641654
chains: int,
655+
traces: Sequence[BaseTrace],
642656
start: Sequence[PointType],
643657
random_seed: Optional[Sequence[RandomSeed]],
644658
step,
645659
callback=None,
646660
**kwargs,
647-
) -> MultiTrace:
661+
):
648662
"""Samples all chains sequentially.
649663
650664
Parameters
@@ -659,35 +673,19 @@ def _sample_many(
659673
A list of seeds, one for each chain
660674
step: function
661675
Step function
662-
663-
Returns
664-
-------
665-
mtrace: MultiTrace
666-
Contains samples of all chains
667676
"""
668-
traces: List[BaseTrace] = []
669677
for i in range(chains):
670-
trace = _sample(
678+
_sample(
671679
draws=draws,
672680
chain=i,
673681
start=start[i],
674682
step=step,
683+
trace=traces[i],
675684
random_seed=None if random_seed is None else random_seed[i],
676685
callback=callback,
677686
**kwargs,
678687
)
679-
if trace is None:
680-
if len(traces) == 0:
681-
raise ValueError("Sampling stopped before a sample was created.")
682-
else:
683-
break
684-
elif len(trace) < draws:
685-
if len(traces) == 0:
686-
traces.append(trace)
687-
break
688-
else:
689-
traces.append(trace)
690-
return MultiTrace(traces)
688+
return
691689

692690

693691
def _sample(
@@ -698,12 +696,12 @@ def _sample(
698696
start: PointType,
699697
draws: int,
700698
step=None,
701-
trace: Optional[BaseTrace] = None,
699+
trace: BaseTrace,
702700
tune: int,
703701
model: Optional[Model] = None,
704702
callback=None,
705703
**kwargs,
706-
) -> BaseTrace:
704+
) -> None:
707705
"""Main iteration for singleprocess sampling.
708706
709707
Multiple step methods are supported via compound step methods.
@@ -724,16 +722,10 @@ def _sample(
724722
step : function
725723
Step function
726724
trace : backend, optional
727-
A backend instance or None.
728-
If None, the NDArray backend is used.
725+
A backend instance.
729726
tune : int
730727
Number of iterations to tune.
731728
model : Model (optional if in ``with`` context)
732-
733-
Returns
734-
-------
735-
strace : BaseTrace
736-
A ``BaseTrace`` object that contains the samples for this chain.
737729
"""
738730
skip_first = kwargs.get("skip_first", 0)
739731

@@ -756,31 +748,27 @@ def _sample(
756748
else:
757749
sampling = sampling_gen
758750
try:
759-
strace = None
760-
for it, (strace, diverging) in enumerate(sampling):
751+
for it, diverging in enumerate(sampling):
761752
if it >= skip_first and diverging:
762753
_pbar_data["divergences"] += 1
763754
if progressbar:
764755
sampling.comment = _desc.format(**_pbar_data)
765756
except KeyboardInterrupt:
766757
pass
767-
if strace is None:
768-
raise Exception("KeyboardInterrupt happened before the base trace was created.")
769-
return strace
770758

771759

772760
def _iter_sample(
773761
*,
774762
draws: int,
775763
step,
776764
start: PointType,
777-
trace: Optional[BaseTrace] = None,
765+
trace: BaseTrace,
778766
chain: int = 0,
779767
tune: int = 0,
780768
model=None,
781769
random_seed: RandomSeed = None,
782770
callback=None,
783-
) -> Iterator[Tuple[BaseTrace, bool]]:
771+
) -> Iterator[bool]:
784772
"""Generator for sampling one chain. (Used in singleprocess sampling.)
785773
786774
Parameters
@@ -792,9 +780,8 @@ def _iter_sample(
792780
start : dict
793781
Starting point in parameter space (or partial point).
794782
Must contain numeric (transformed) initial values for all (transformed) free variables.
795-
trace : backend, optional
796-
A backend instance or None.
797-
If None, the NDArray backend is used.
783+
trace : backend
784+
A backend instance.
798785
chain : int, optional
799786
Chain number used to store sample in backend.
800787
tune : int, optional
@@ -804,8 +791,6 @@ def _iter_sample(
804791
805792
Yields
806793
------
807-
strace : BaseTrace
808-
The trace object containing the samples for this chain
809794
diverging : bool
810795
Indicates if the draw is divergent. Only available with some samplers.
811796
"""
@@ -825,14 +810,6 @@ def _iter_sample(
825810

826811
point = start
827812

828-
strace: BaseTrace = _init_trace(
829-
expected_length=draws + tune,
830-
stats_dtypes=step.stats_dtypes,
831-
chain_number=chain,
832-
trace=trace,
833-
model=model,
834-
)
835-
836813
try:
837814
step.tune = bool(tune)
838815
if hasattr(step, "reset_tuning"):
@@ -846,24 +823,24 @@ def _iter_sample(
846823
if i == tune:
847824
step.stop_tuning()
848825
point, stats = step.step(point)
849-
strace.record(point, stats)
826+
trace.record(point, stats)
850827
log_warning_stats(stats)
851828
diverging = i > tune and stats and stats[0].get("diverging")
852829
if callback is not None:
853830
callback(
854-
trace=strace,
831+
trace=trace,
855832
draw=Draw(chain, i == draws, i, i < tune, stats, point),
856833
)
857834

858-
yield strace, diverging
835+
yield diverging
859836
except KeyboardInterrupt:
860-
strace.close()
837+
trace.close()
861838
raise
862839
except BaseException:
863-
strace.close()
840+
trace.close()
864841
raise
865842
else:
866-
strace.close()
843+
trace.close()
867844

868845

869846
def _mp_sample(
@@ -876,12 +853,12 @@ def _mp_sample(
876853
random_seed: Sequence[RandomSeed],
877854
start: Sequence[PointType],
878855
progressbar: bool = True,
879-
trace: Optional[BaseTrace] = None,
856+
traces: Sequence[BaseTrace],
880857
model=None,
881858
callback=None,
882859
mp_ctx=None,
883860
**kwargs,
884-
) -> List[BaseTrace]:
861+
) -> None:
885862
"""Main iteration for multiprocess sampling.
886863
887864
Parameters
@@ -913,28 +890,12 @@ def _mp_sample(
913890
the ``draw.chain`` argument can be used to determine which of the active chains the sample
914891
is drawn from.
915892
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
916-
917-
Returns
918-
-------
919-
traces
920-
All chains.
921893
"""
922894
import pymc.sampling.parallel as ps
923895

924896
# We did draws += tune in pm.sample
925897
draws -= tune
926898

927-
traces = [
928-
_init_trace(
929-
expected_length=draws + tune,
930-
stats_dtypes=step.stats_dtypes,
931-
chain_number=chain_number,
932-
trace=trace,
933-
model=model,
934-
)
935-
for chain_number in range(chains)
936-
]
937-
938899
sampler = ps.ParallelSampler(
939900
draws=draws,
940901
tune=tune,
@@ -957,7 +918,7 @@ def _mp_sample(
957918
strace.close()
958919

959920
if callback is not None:
960-
callback(trace=trace, draw=draw)
921+
callback(trace=strace, draw=draw)
961922

962923
except ps.ParallelSamplingError as error:
963924
strace = traces[error._chain]
@@ -967,9 +928,8 @@ def _mp_sample(
967928
multitrace = MultiTrace(traces)
968929
multitrace._report._log_summary()
969930
raise
970-
return traces
971931
except KeyboardInterrupt:
972-
return traces
932+
pass
973933
finally:
974934
for strace in traces:
975935
strace.close()

0 commit comments

Comments
 (0)