33
33
import pymc as pm
34
34
35
35
from pymc .backends import _init_trace
36
- from pymc .backends .base import BaseTrace , MultiTrace , _choose_chains
36
+ from pymc .backends .base import BaseTrace , IBaseTrace , MultiTrace , _choose_chains
37
37
from pymc .blocking import DictToArrayBijection
38
38
from pymc .exceptions import SamplingError
39
39
from pymc .initial_point import PointType , StartDict , make_initial_point_fns_per_chain
71
71
class SamplingIteratorCallback (Protocol ):
72
72
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
73
73
74
- def __call__ (self , trace : BaseTrace , draw : Draw ):
74
+ def __call__ (self , trace : IBaseTrace , draw : Draw ):
75
75
pass
76
76
77
77
@@ -657,7 +657,7 @@ def _sample_many(
657
657
* ,
658
658
draws : int ,
659
659
chains : int ,
660
- traces : Sequence [BaseTrace ],
660
+ traces : Sequence [IBaseTrace ],
661
661
start : Sequence [PointType ],
662
662
random_seed : Optional [Sequence [RandomSeed ]],
663
663
step : Step ,
@@ -701,7 +701,7 @@ def _sample(
701
701
start : PointType ,
702
702
draws : int ,
703
703
step : Step ,
704
- trace : BaseTrace ,
704
+ trace : IBaseTrace ,
705
705
tune : int ,
706
706
model : Optional [Model ] = None ,
707
707
callback = None ,
@@ -726,8 +726,8 @@ def _sample(
726
726
The number of samples to draw
727
727
step : function
728
728
Step function
729
- trace : backend, optional
730
- A backend instance .
729
+ trace
730
+ A chain backend to record draws and stats .
731
731
tune : int
732
732
Number of iterations to tune.
733
733
model : Model (optional if in ``with`` context)
@@ -767,7 +767,7 @@ def _iter_sample(
767
767
draws : int ,
768
768
step : Step ,
769
769
start : PointType ,
770
- trace : BaseTrace ,
770
+ trace : IBaseTrace ,
771
771
chain : int = 0 ,
772
772
tune : int = 0 ,
773
773
model : Optional [Model ] = None ,
@@ -785,8 +785,8 @@ def _iter_sample(
785
785
start : dict
786
786
Starting point in parameter space (or partial point).
787
787
Must contain numeric (transformed) initial values for all (transformed) free variables.
788
- trace : backend
789
- A backend instance .
788
+ trace
789
+ A chain backend to record draws and stats .
790
790
chain : int, optional
791
791
Chain number used to store sample in backend.
792
792
tune : int, optional
@@ -852,7 +852,7 @@ def _mp_sample(
852
852
random_seed : Sequence [RandomSeed ],
853
853
start : Sequence [PointType ],
854
854
progressbar : bool = True ,
855
- traces : Sequence [BaseTrace ],
855
+ traces : Sequence [IBaseTrace ],
856
856
model : Optional [Model ] = None ,
857
857
callback : Optional [SamplingIteratorCallback ] = None ,
858
858
mp_ctx = None ,
@@ -879,9 +879,8 @@ def _mp_sample(
879
879
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
880
880
progressbar : bool
881
881
Whether or not to display a progress bar in the command line.
882
- trace : BaseTrace, optional
883
- A backend instance, or None.
884
- If None, the NDArray backend is used.
882
+ traces
883
+ Recording backends for each chain.
885
884
model : Model (optional if in ``with`` context)
886
885
callback
887
886
A function which gets called for every sample from the trace of a chain. The function is
0 commit comments