Skip to content

Commit c5aff17

Browse files
Modify internal sampling functions from BaseTrace to IBaseTrace
1 parent e739546 commit c5aff17

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

pymc/sampling/mcmc.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import pymc as pm
3434

3535
from pymc.backends import _init_trace
36-
from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains
36+
from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains
3737
from pymc.blocking import DictToArrayBijection
3838
from pymc.exceptions import SamplingError
3939
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
@@ -71,7 +71,7 @@
7171
class SamplingIteratorCallback(Protocol):
7272
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
7373

74-
def __call__(self, trace: BaseTrace, draw: Draw):
74+
def __call__(self, trace: IBaseTrace, draw: Draw):
7575
pass
7676

7777

@@ -657,7 +657,7 @@ def _sample_many(
657657
*,
658658
draws: int,
659659
chains: int,
660-
traces: Sequence[BaseTrace],
660+
traces: Sequence[IBaseTrace],
661661
start: Sequence[PointType],
662662
random_seed: Optional[Sequence[RandomSeed]],
663663
step: Step,
@@ -701,7 +701,7 @@ def _sample(
701701
start: PointType,
702702
draws: int,
703703
step: Step,
704-
trace: BaseTrace,
704+
trace: IBaseTrace,
705705
tune: int,
706706
model: Optional[Model] = None,
707707
callback=None,
@@ -726,8 +726,8 @@ def _sample(
726726
The number of samples to draw
727727
step : function
728728
Step function
729-
trace : backend, optional
730-
A backend instance.
729+
trace
730+
A chain backend to record draws and stats.
731731
tune : int
732732
Number of iterations to tune.
733733
model : Model (optional if in ``with`` context)
@@ -767,7 +767,7 @@ def _iter_sample(
767767
draws: int,
768768
step: Step,
769769
start: PointType,
770-
trace: BaseTrace,
770+
trace: IBaseTrace,
771771
chain: int = 0,
772772
tune: int = 0,
773773
model: Optional[Model] = None,
@@ -785,8 +785,8 @@ def _iter_sample(
785785
start : dict
786786
Starting point in parameter space (or partial point).
787787
Must contain numeric (transformed) initial values for all (transformed) free variables.
788-
trace : backend
789-
A backend instance.
788+
trace
789+
A chain backend to record draws and stats.
790790
chain : int, optional
791791
Chain number used to store sample in backend.
792792
tune : int, optional
@@ -852,7 +852,7 @@ def _mp_sample(
852852
random_seed: Sequence[RandomSeed],
853853
start: Sequence[PointType],
854854
progressbar: bool = True,
855-
traces: Sequence[BaseTrace],
855+
traces: Sequence[IBaseTrace],
856856
model: Optional[Model] = None,
857857
callback: Optional[SamplingIteratorCallback] = None,
858858
mp_ctx=None,
@@ -879,9 +879,8 @@ def _mp_sample(
879879
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
880880
progressbar : bool
881881
Whether or not to display a progress bar in the command line.
882-
trace : BaseTrace, optional
883-
A backend instance, or None.
884-
If None, the NDArray backend is used.
882+
traces
883+
Recording backends for each chain.
885884
model : Model (optional if in ``with`` context)
886885
callback
887886
A function which gets called for every sample from the trace of a chain. The function is

0 commit comments

Comments
 (0)