21
21
import warnings
22
22
23
23
from collections import defaultdict
24
- from typing import Iterator , List , Optional , Sequence , Tuple , Union
24
+ from typing import Any , Dict , Iterator , List , Optional , Sequence , Tuple , Union
25
25
26
26
import numpy as np
27
27
import pytensor .gradient as tg
28
28
29
29
from arviz import InferenceData
30
30
from fastprogress .fastprogress import progress_bar
31
- from typing_extensions import TypeAlias
31
+ from typing_extensions import Protocol , TypeAlias
32
32
33
33
import pymc as pm
34
34
64
64
Step : TypeAlias = Union [BlockedStep , CompoundStep ]
65
65
66
66
67
+ class SamplingIteratorCallback (Protocol ):
68
+ """Signature of the callable that may be passed to `pm.sample(callable=...)`."""
69
+
70
+ def __call__ (self , trace : BaseTrace , draw : Draw ):
71
+ pass
72
+
73
+
67
74
_log = logging .getLogger ("pymc" )
68
75
69
76
@@ -221,7 +228,7 @@ def sample(
221
228
cores : Optional [int ] = None ,
222
229
tune : int = 1000 ,
223
230
progressbar : bool = True ,
224
- model = None ,
231
+ model : Optional [ Model ] = None ,
225
232
random_seed : RandomState = None ,
226
233
discard_tuned_samples : bool = True ,
227
234
compute_convergence_checks : bool = True ,
@@ -599,7 +606,7 @@ def sample(
599
606
600
607
idata = None
601
608
if compute_convergence_checks or return_inferencedata :
602
- ikwargs = dict (model = model , save_warmup = not discard_tuned_samples )
609
+ ikwargs : Dict [ str , Any ] = dict (model = model , save_warmup = not discard_tuned_samples )
603
610
if idata_kwargs :
604
611
ikwargs .update (idata_kwargs )
605
612
idata = pm .to_inference_data (mtrace , ** ikwargs )
@@ -655,8 +662,8 @@ def _sample_many(
655
662
traces : Sequence [BaseTrace ],
656
663
start : Sequence [PointType ],
657
664
random_seed : Optional [Sequence [RandomSeed ]],
658
- step ,
659
- callback = None ,
665
+ step : Step ,
666
+ callback : Optional [ SamplingIteratorCallback ] = None ,
660
667
** kwargs ,
661
668
):
662
669
"""Samples all chains sequentially.
@@ -695,7 +702,7 @@ def _sample(
695
702
random_seed : RandomSeed ,
696
703
start : PointType ,
697
704
draws : int ,
698
- step = None ,
705
+ step : Step ,
699
706
trace : BaseTrace ,
700
707
tune : int ,
701
708
model : Optional [Model ] = None ,
@@ -760,14 +767,14 @@ def _sample(
760
767
def _iter_sample (
761
768
* ,
762
769
draws : int ,
763
- step ,
770
+ step : Step ,
764
771
start : PointType ,
765
772
trace : BaseTrace ,
766
773
chain : int = 0 ,
767
774
tune : int = 0 ,
768
- model = None ,
775
+ model : Optional [ Model ] = None ,
769
776
random_seed : RandomSeed = None ,
770
- callback = None ,
777
+ callback : Optional [ SamplingIteratorCallback ] = None ,
771
778
) -> Iterator [bool ]:
772
779
"""Generator for sampling one chain. (Used in singleprocess sampling.)
773
780
@@ -803,19 +810,13 @@ def _iter_sample(
803
810
if random_seed is not None :
804
811
np .random .seed (random_seed )
805
812
806
- try :
807
- step = CompoundStep (step )
808
- except TypeError :
809
- pass
810
-
811
813
point = start
812
814
813
815
try :
814
816
step .tune = bool (tune )
815
817
if hasattr (step , "reset_tuning" ):
816
818
step .reset_tuning ()
817
819
for i in range (draws ):
818
- stats = None
819
820
diverging = False
820
821
821
822
if i == 0 and hasattr (step , "iter_count" ):
@@ -825,7 +826,7 @@ def _iter_sample(
825
826
point , stats = step .step (point )
826
827
trace .record (point , stats )
827
828
log_warning_stats (stats )
828
- diverging = i > tune and stats and stats [0 ].get ("diverging" )
829
+ diverging = i > tune and len ( stats ) > 0 and ( stats [0 ].get ("diverging" ) == True )
829
830
if callback is not None :
830
831
callback (
831
832
trace = trace ,
@@ -854,8 +855,8 @@ def _mp_sample(
854
855
start : Sequence [PointType ],
855
856
progressbar : bool = True ,
856
857
traces : Sequence [BaseTrace ],
857
- model = None ,
858
- callback = None ,
858
+ model : Optional [ Model ] = None ,
859
+ callback : Optional [ SamplingIteratorCallback ] = None ,
859
860
mp_ctx = None ,
860
861
** kwargs ,
861
862
) -> None :
@@ -884,7 +885,7 @@ def _mp_sample(
884
885
A backend instance, or None.
885
886
If None, the NDArray backend is used.
886
887
model : Model (optional if in ``with`` context)
887
- callback : Callable
888
+ callback
888
889
A function which gets called for every sample from the trace of a chain. The function is
889
890
called with the trace and the current draw and will contain all samples for a single trace.
890
891
the ``draw.chain`` argument can be used to determine which of the active chains the sample
@@ -994,7 +995,7 @@ def init_nuts(
994
995
init : str = "auto" ,
995
996
chains : int = 1 ,
996
997
n_init : int = 500_000 ,
997
- model = None ,
998
+ model : Optional [ Model ] = None ,
998
999
random_seed : RandomSeed = None ,
999
1000
progressbar = True ,
1000
1001
jitter_max_retries : int = 10 ,
0 commit comments