24
24
from collections import defaultdict
25
25
from copy import copy
26
26
from typing import (
27
+ Any ,
28
+ Callable ,
27
29
Dict ,
28
30
Iterable ,
29
31
Iterator ,
@@ -811,12 +813,16 @@ def _sample(
811
813
812
814
trace = copy (trace )
813
815
814
- sampling = _iter_sample (draws , step , start , trace , chain , tune , model , random_seed , callback )
816
+ sampling_gen = _iter_sample (
817
+ draws , step , start , trace , chain , tune , model , random_seed , callback
818
+ )
815
819
_pbar_data = {"chain" : chain , "divergences" : 0 }
816
820
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
817
821
if progressbar :
818
- sampling = progress_bar (sampling , total = draws , display = progressbar )
822
+ sampling = progress_bar (sampling_gen , total = draws , display = progressbar )
819
823
sampling .comment = _desc .format (** _pbar_data )
824
+ else :
825
+ sampling = sampling_gen
820
826
try :
821
827
strace = None
822
828
for it , (strace , diverging ) in enumerate (sampling ):
@@ -826,6 +832,8 @@ def _sample(
826
832
sampling .comment = _desc .format (** _pbar_data )
827
833
except KeyboardInterrupt :
828
834
pass
835
+ if strace is None :
836
+ raise Exception ("KeyboardInterrupt happened before the base trace was created." )
829
837
return strace
830
838
831
839
@@ -1494,10 +1502,12 @@ def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTra
1494
1502
idxs = np .argsort (lengths )
1495
1503
l_sort = np .array (lengths )[idxs ]
1496
1504
1497
- use_until = np .argmax (l_sort * np .arange (1 , l_sort .shape [0 ] + 1 )[::- 1 ])
1505
+ use_until = cast ( int , np .argmax (l_sort * np .arange (1 , l_sort .shape [0 ] + 1 )[::- 1 ]) )
1498
1506
final_length = l_sort [use_until ]
1499
1507
1500
- return [traces [idx ] for idx in idxs [use_until :]], final_length + tune
1508
+ take_idx = cast (Sequence [int ], idxs [use_until :])
1509
+ sliced_traces = [traces [idx ] for idx in take_idx ]
1510
+ return sliced_traces , final_length + tune
1501
1511
1502
1512
1503
1513
def stop_tuning (step ):
@@ -1590,30 +1600,30 @@ def sample_posterior_predictive(
1590
1600
"""
1591
1601
1592
1602
_trace : Union [MultiTrace , PointList ]
1603
+ nchain : int
1593
1604
if isinstance (trace , InferenceData ):
1594
1605
_trace = dataset_to_point_list (trace .posterior )
1606
+ nchain , len_trace = chains_and_samples (trace )
1595
1607
elif isinstance (trace , xarray .Dataset ):
1596
1608
_trace = dataset_to_point_list (trace )
1597
- else :
1609
+ nchain , len_trace = chains_and_samples (trace )
1610
+ elif isinstance (trace , MultiTrace ):
1598
1611
_trace = trace
1612
+ nchain = _trace .nchains
1613
+ len_trace = len (_trace )
1614
+ elif isinstance (trace , list ) and all (isinstance (x , dict ) for x in trace ):
1615
+ _trace = trace
1616
+ nchain = 1
1617
+ len_trace = len (_trace )
1618
+ else :
1619
+ raise TypeError (f"Unsupported type for `trace` argument: { type (trace )} ." )
1599
1620
1600
1621
if keep_size is None :
1601
1622
# This will allow users to set return_inferencedata=False and
1602
1623
# automatically get the old behaviour instead of needing to
1603
1624
# set both return_inferencedata and keep_size to False
1604
1625
keep_size = return_inferencedata
1605
1626
1606
- nchain : int
1607
- len_trace : int
1608
- if isinstance (trace , (InferenceData , xarray .Dataset )):
1609
- nchain , len_trace = chains_and_samples (trace )
1610
- else :
1611
- len_trace = len (_trace )
1612
- try :
1613
- nchain = _trace .nchains
1614
- except AttributeError :
1615
- nchain = 1
1616
-
1617
1627
if keep_size and samples is not None :
1618
1628
raise IncorrectArgumentsError (
1619
1629
"Should not specify both keep_size and samples arguments. "
@@ -1625,7 +1635,7 @@ def sample_posterior_predictive(
1625
1635
if samples is None :
1626
1636
if isinstance (_trace , MultiTrace ):
1627
1637
samples = sum (len (v ) for v in _trace ._straces .values ())
1628
- elif isinstance (_trace , list ) and all ( isinstance ( x , dict ) for x in _trace ) :
1638
+ elif isinstance (_trace , list ):
1629
1639
# this is a list of points
1630
1640
samples = len (_trace )
1631
1641
else :
@@ -1693,6 +1703,7 @@ def sample_posterior_predictive(
1693
1703
else :
1694
1704
inputs , input_names = [], []
1695
1705
else :
1706
+ assert isinstance (_trace , MultiTrace )
1696
1707
output_names = [v .name for v in vars_to_sample if v .name is not None ]
1697
1708
input_names = [
1698
1709
n
@@ -1715,7 +1726,7 @@ def sample_posterior_predictive(
1715
1726
1716
1727
ppc_trace_t = _DefaultTrace (samples )
1717
1728
try :
1718
- if hasattr (_trace , "_straces" ):
1729
+ if isinstance (_trace , MultiTrace ):
1719
1730
# trace dict is unordered, but we want to return ppc samples in
1720
1731
# a predictable ordering, so sort the chain indices
1721
1732
chain_idx_mapping = sorted (_trace ._straces .keys ())
@@ -1750,7 +1761,7 @@ def sample_posterior_predictive(
1750
1761
1751
1762
if not return_inferencedata :
1752
1763
return ppc_trace
1753
- ikwargs = dict (model = model )
1764
+ ikwargs : Dict [ str , Any ] = dict (model = model )
1754
1765
if idata_kwargs :
1755
1766
ikwargs .update (idata_kwargs )
1756
1767
if predictions :
@@ -1881,8 +1892,8 @@ def sample_posterior_predictive_w(
1881
1892
indices = np .random .randint (0 , nchain * len_trace , j )
1882
1893
if nchain > 1 :
1883
1894
chain_idx , point_idx = np .divmod (indices , len_trace )
1884
- for idx in zip (chain_idx , point_idx ):
1885
- trace .append (tr ._straces [idx [ 0 ]] .point (idx [ 1 ] ))
1895
+ for cidx , pidx in zip (chain_idx , point_idx ):
1896
+ trace .append (tr ._straces [cidx ] .point (pidx ))
1886
1897
else :
1887
1898
for idx in indices :
1888
1899
trace .append (tr [idx ])
@@ -1892,12 +1903,12 @@ def sample_posterior_predictive_w(
1892
1903
1893
1904
lengths = list ({np .atleast_1d (observed ).shape for observed in obs })
1894
1905
1906
+ size : List [Optional [Tuple [int , ...]]] = []
1895
1907
if len (lengths ) == 1 :
1896
- size = [None for i in variables ]
1908
+ size = [None ] * len ( variables )
1897
1909
elif len (lengths ) > 2 :
1898
1910
raise ValueError ("Observed variables could not be broadcast together" )
1899
1911
else :
1900
- size = []
1901
1912
x = np .zeros (shape = lengths [0 ])
1902
1913
y = np .zeros (shape = lengths [1 ])
1903
1914
b = np .broadcast (x , y )
@@ -1919,7 +1930,7 @@ def sample_posterior_predictive_w(
1919
1930
indices = progress_bar (indices , total = samples , display = progressbar )
1920
1931
1921
1932
try :
1922
- ppc = defaultdict (list )
1933
+ ppcl : Dict [ str , list ] = defaultdict (list )
1923
1934
for idx in indices :
1924
1935
param = trace [idx ]
1925
1936
var = variables [idx ]
@@ -1932,13 +1943,13 @@ def sample_posterior_predictive_w(
1932
1943
except KeyboardInterrupt :
1933
1944
pass
1934
1945
else :
1935
- ppc = {k : np .asarray (v ) for k , v in ppc .items ()}
1946
+ ppcd = {k : np .asarray (v ) for k , v in ppcl .items ()}
1936
1947
if not return_inferencedata :
1937
- return ppc
1938
- ikwargs = dict (model = models )
1948
+ return ppcd
1949
+ ikwargs : Dict [ str , Any ] = dict (model = models )
1939
1950
if idata_kwargs :
1940
1951
ikwargs .update (idata_kwargs )
1941
- return pm .to_inference_data (posterior_predictive = ppc , ** ikwargs )
1952
+ return pm .to_inference_data (posterior_predictive = ppcd , ** ikwargs )
1942
1953
1943
1954
1944
1955
def sample_prior_predictive (
@@ -2044,7 +2055,7 @@ def sample_prior_predictive(
2044
2055
2045
2056
if not return_inferencedata :
2046
2057
return prior
2047
- ikwargs = dict (model = model )
2058
+ ikwargs : Dict [ str , Any ] = dict (model = model )
2048
2059
if idata_kwargs :
2049
2060
ikwargs .update (idata_kwargs )
2050
2061
return pm .to_inference_data (prior = prior , ** ikwargs )
@@ -2106,10 +2117,11 @@ def draw(
2106
2117
2107
2118
# Single variable output
2108
2119
if not isinstance (vars , (list , tuple )):
2109
- drawn_values = ( draw_fn () for _ in range ( draws ) )
2110
- return np .stack (drawn_values )
2120
+ cast ( Callable [[], np . ndarray ], draw_fn )
2121
+ return np .stack ([ draw_fn () for _ in range ( draws )] )
2111
2122
2112
2123
# Multiple variable output
2124
+ cast (Callable [[], List [np .ndarray ]], draw_fn )
2113
2125
drawn_values = zip (* (draw_fn () for _ in range (draws )))
2114
2126
return [np .stack (v ) for v in drawn_values ]
2115
2127
@@ -2120,7 +2132,7 @@ def _init_jitter(
2120
2132
seeds : Sequence [int ],
2121
2133
jitter : bool ,
2122
2134
jitter_max_retries : int ,
2123
- ) -> PointType :
2135
+ ) -> List [ PointType ] :
2124
2136
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
2125
2137
2126
2138
``model.check_start_vals`` is used to test whether the jittered starting
@@ -2144,7 +2156,7 @@ def _init_jitter(
2144
2156
ipfns = make_initial_point_fns_per_chain (
2145
2157
model = model ,
2146
2158
overrides = initvals ,
2147
- jitter_rvs = set (model .free_RVs ) if jitter else {} ,
2159
+ jitter_rvs = set (model .free_RVs ) if jitter else set () ,
2148
2160
chains = len (seeds ),
2149
2161
)
2150
2162
@@ -2282,6 +2294,7 @@ def init_nuts(
2282
2294
2283
2295
apoints = [DictToArrayBijection .map (point ) for point in initial_points ]
2284
2296
apoints_data = [apoint .data for apoint in apoints ]
2297
+ potential : quadpotential .QuadPotential
2285
2298
2286
2299
if init == "adapt_diag" :
2287
2300
mean = np .mean (apoints_data , axis = 0 )
0 commit comments