Commit 08b6b01

Fix more typing problems
1 parent 97f0f79 commit 08b6b01

9 files changed: +70, -45 lines

pymc/aesaraf.py

Lines changed: 3 additions & 1 deletion
@@ -941,7 +941,9 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
     )


-def compile_pymc(inputs, outputs, mode=None, **kwargs):
+def compile_pymc(
+    inputs, outputs, mode=None, **kwargs
+) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
     """Use ``aesara.function`` with specialized pymc rewrites always enabled.

     Included rewrites
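
The added return annotation documents that compile_pymc hands back a callable whose result is either a single array (one output) or a list of arrays (several outputs). A minimal standalone sketch of how such an annotation reads, using a toy stand-in rather than aesara.function (names here are illustrative only):

from typing import Callable, List, Union

import numpy as np


def make_scaler(factors: List[float]) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Toy stand-in: one output -> a single array, several outputs -> a list."""

    def fn(x: float) -> Union[np.ndarray, List[np.ndarray]]:
        outs = [np.full(3, f * x) for f in factors]
        return outs[0] if len(outs) == 1 else outs

    return fn


single = make_scaler([2.0])(1.5)        # a single ndarray of shape (3,)
several = make_scaler([1.0, 2.0])(1.5)  # a list of two ndarrays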

pymc/backends/arviz.py

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@
     Any,
     Dict,
     Iterable,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -532,8 +533,8 @@ def to_inference_data(self):
 def to_inference_data(
     trace: Optional["MultiTrace"] = None,
     *,
-    prior: Optional[Dict[str, Any]] = None,
-    posterior_predictive: Optional[Dict[str, Any]] = None,
+    prior: Optional[Mapping[str, Any]] = None,
+    posterior_predictive: Optional[Mapping[str, Any]] = None,
     log_likelihood: Union[bool, Iterable[str]] = True,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
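
Widening the parameter type from Dict to Mapping is a standard typing loosening: the function only reads from prior and posterior_predictive, so any read-only dict-like object is acceptable, not just dict instances, and the annotation signals that the argument will not be mutated. A small generic sketch (not PyMC code):

from types import MappingProxyType
from typing import Any, Mapping, Optional


def summarize(prior: Optional[Mapping[str, Any]] = None) -> int:
    # Read-only access, so any mapping works here.
    return 0 if prior is None else len(prior)


summarize({"mu": [0.1, 0.2]})                  # a plain dict
summarize(MappingProxyType({"sigma": [1.0]}))  # a read-only mapping view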

pymc/func_utils.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings

-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union

 import aesara.tensor as aet
 import numpy as np
@@ -129,6 +129,7 @@ def find_constrained_prior(
     cdf_error = (pm.math.exp(logcdf_upper) - pm.math.exp(logcdf_lower)) - mass
     cdf_error_fn = pm.aesaraf.compile_pymc([dist_params], cdf_error, allow_input_downcast=True)

+    jac: Union[str, Callable]
     try:
         aesara_jac = pm.gradient(cdf_error, [dist_params])
         jac = pm.aesaraf.compile_pymc([dist_params], aesara_jac, allow_input_downcast=True)
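
The new jac: Union[str, Callable] line declares the variable's type before the try/except assigns it, so a later assignment of either a string or a callable still type-checks; without it, mypy pins the variable to whatever the first assignment infers. A standalone sketch of the pattern (illustrative names, not the PyMC implementation):

import math
from typing import Callable, Union


def pick_jacobian(use_gradient: bool) -> Union[str, Callable]:
    # Declared up front so mypy accepts an assignment of either type below.
    jac: Union[str, Callable]
    try:
        if not use_gradient:
            raise NotImplementedError("no symbolic gradient available")
        jac = math.cos  # a callable, standing in for a compiled gradient function
    except NotImplementedError:
        jac = "2-point"  # a string option, as accepted by e.g. scipy.optimize.minimize
    return jac


print(pick_jacobian(True), pick_jacobian(False))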

pymc/model.py

Lines changed: 3 additions & 2 deletions
@@ -1443,8 +1443,9 @@ def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]]
         if dims is not None:
             if isinstance(dims, str):
                 dims = (dims,)
-            if any(dim not in self.coords and dim is not None for dim in dims):
-                raise ValueError(f"Dimension {dim} is not specified in `coords`.")
+            for dim in dims:
+                if dim not in self.coords and dim is not None:
+                    raise ValueError(f"Dimension {dim} is not specified in `coords`.")
             if any(var.name == dim for dim in dims):
                 raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
         self._RV_dims[var.name] = dims
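
This one is a behavior fix as well as a typing fix: in the old form, dim was bound only inside the any(...) generator, so the f-string in the raise referenced a name that is not defined at that point (or, at best, a stale binding), while the explicit loop keeps the offending dimension in scope for the error message. A standalone illustration with made-up coords and dims values:

coords = {"city": ["a", "b"]}
dims = ("city", "year")

try:
    # Old style: the generator's `dim` does not leak into the enclosing scope.
    if any(dim not in coords and dim is not None for dim in dims):
        raise ValueError(f"Dimension {dim} is not specified in `coords`.")
except NameError as err:
    print(err)  # name 'dim' is not defined

# New style: the loop variable is available when the message is built.
for dim in dims:
    if dim not in coords and dim is not None:
        print(f"Dimension {dim} is not specified in `coords`.")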

pymc/sampling.py

Lines changed: 46 additions & 33 deletions
@@ -24,6 +24,8 @@
 from collections import defaultdict
 from copy import copy
 from typing import (
+    Any,
+    Callable,
     Dict,
     Iterable,
     Iterator,
@@ -811,12 +813,16 @@ def _sample(

     trace = copy(trace)

-    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
+    sampling_gen = _iter_sample(
+        draws, step, start, trace, chain, tune, model, random_seed, callback
+    )
     _pbar_data = {"chain": chain, "divergences": 0}
     _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
     if progressbar:
-        sampling = progress_bar(sampling, total=draws, display=progressbar)
+        sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
         sampling.comment = _desc.format(**_pbar_data)
+    else:
+        sampling = sampling_gen
     try:
         strace = None
         for it, (strace, diverging) in enumerate(sampling):
@@ -826,6 +832,8 @@ def _sample(
                 sampling.comment = _desc.format(**_pbar_data)
     except KeyboardInterrupt:
         pass
+    if strace is None:
+        raise Exception("KeyboardInterrupt happened before the base trace was created.")
     return strace


@@ -1494,10 +1502,12 @@ def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTra
     idxs = np.argsort(lengths)
     l_sort = np.array(lengths)[idxs]

-    use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])
+    use_until = cast(int, np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1]))
     final_length = l_sort[use_until]

-    return [traces[idx] for idx in idxs[use_until:]], final_length + tune
+    take_idx = cast(Sequence[int], idxs[use_until:])
+    sliced_traces = [traces[idx] for idx in take_idx]
+    return sliced_traces, final_length + tune


 def stop_tuning(step):
@@ -1590,30 +1600,30 @@ def sample_posterior_predictive(
     """

     _trace: Union[MultiTrace, PointList]
+    nchain: int
     if isinstance(trace, InferenceData):
         _trace = dataset_to_point_list(trace.posterior)
+        nchain, len_trace = chains_and_samples(trace)
     elif isinstance(trace, xarray.Dataset):
         _trace = dataset_to_point_list(trace)
-    else:
+        nchain, len_trace = chains_and_samples(trace)
+    elif isinstance(trace, MultiTrace):
         _trace = trace
+        nchain = _trace.nchains
+        len_trace = len(_trace)
+    elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
+        _trace = trace
+        nchain = 1
+        len_trace = len(_trace)
+    else:
+        raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")

     if keep_size is None:
         # This will allow users to set return_inferencedata=False and
         # automatically get the old behaviour instead of needing to
         # set both return_inferencedata and keep_size to False
         keep_size = return_inferencedata

-    nchain: int
-    len_trace: int
-    if isinstance(trace, (InferenceData, xarray.Dataset)):
-        nchain, len_trace = chains_and_samples(trace)
-    else:
-        len_trace = len(_trace)
-        try:
-            nchain = _trace.nchains
-        except AttributeError:
-            nchain = 1
-
     if keep_size and samples is not None:
         raise IncorrectArgumentsError(
             "Should not specify both keep_size and samples arguments. "
@@ -1625,7 +1635,7 @@ def sample_posterior_predictive(
     if samples is None:
         if isinstance(_trace, MultiTrace):
             samples = sum(len(v) for v in _trace._straces.values())
-        elif isinstance(_trace, list) and all(isinstance(x, dict) for x in _trace):
+        elif isinstance(_trace, list):
            # this is a list of points
             samples = len(_trace)
         else:
@@ -1693,6 +1703,7 @@ def sample_posterior_predictive(
         else:
             inputs, input_names = [], []
     else:
+        assert isinstance(_trace, MultiTrace)
         output_names = [v.name for v in vars_to_sample if v.name is not None]
         input_names = [
             n
@@ -1715,7 +1726,7 @@ def sample_posterior_predictive(

     ppc_trace_t = _DefaultTrace(samples)
     try:
-        if hasattr(_trace, "_straces"):
+        if isinstance(_trace, MultiTrace):
             # trace dict is unordered, but we want to return ppc samples in
             # a predictable ordering, so sort the chain indices
             chain_idx_mapping = sorted(_trace._straces.keys())
@@ -1750,7 +1761,7 @@ def sample_posterior_predictive(

     if not return_inferencedata:
         return ppc_trace
-    ikwargs = dict(model=model)
+    ikwargs: Dict[str, Any] = dict(model=model)
     if idata_kwargs:
         ikwargs.update(idata_kwargs)
     if predictions:
@@ -1881,8 +1892,8 @@ def sample_posterior_predictive_w(
         indices = np.random.randint(0, nchain * len_trace, j)
         if nchain > 1:
             chain_idx, point_idx = np.divmod(indices, len_trace)
-            for idx in zip(chain_idx, point_idx):
-                trace.append(tr._straces[idx[0]].point(idx[1]))
+            for cidx, pidx in zip(chain_idx, point_idx):
+                trace.append(tr._straces[cidx].point(pidx))
         else:
             for idx in indices:
                 trace.append(tr[idx])
@@ -1892,12 +1903,12 @@ def sample_posterior_predictive_w(

     lengths = list({np.atleast_1d(observed).shape for observed in obs})

+    size: List[Optional[Tuple[int, ...]]] = []
     if len(lengths) == 1:
-        size = [None for i in variables]
+        size = [None] * len(variables)
     elif len(lengths) > 2:
         raise ValueError("Observed variables could not be broadcast together")
     else:
-        size = []
         x = np.zeros(shape=lengths[0])
         y = np.zeros(shape=lengths[1])
         b = np.broadcast(x, y)
@@ -1919,7 +1930,7 @@ def sample_posterior_predictive_w(
         indices = progress_bar(indices, total=samples, display=progressbar)

     try:
-        ppc = defaultdict(list)
+        ppcl: Dict[str, list] = defaultdict(list)
         for idx in indices:
             param = trace[idx]
             var = variables[idx]
@@ -1932,13 +1943,13 @@ def sample_posterior_predictive_w(
     except KeyboardInterrupt:
         pass
     else:
-        ppc = {k: np.asarray(v) for k, v in ppc.items()}
+        ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
         if not return_inferencedata:
-            return ppc
-        ikwargs = dict(model=models)
+            return ppcd
+        ikwargs: Dict[str, Any] = dict(model=models)
         if idata_kwargs:
             ikwargs.update(idata_kwargs)
-        return pm.to_inference_data(posterior_predictive=ppc, **ikwargs)
+        return pm.to_inference_data(posterior_predictive=ppcd, **ikwargs)


 def sample_prior_predictive(
@@ -2044,7 +2055,7 @@ def sample_prior_predictive(

     if not return_inferencedata:
         return prior
-    ikwargs = dict(model=model)
+    ikwargs: Dict[str, Any] = dict(model=model)
     if idata_kwargs:
         ikwargs.update(idata_kwargs)
     return pm.to_inference_data(prior=prior, **ikwargs)
@@ -2106,10 +2117,11 @@ def draw(

     # Single variable output
     if not isinstance(vars, (list, tuple)):
-        drawn_values = (draw_fn() for _ in range(draws))
-        return np.stack(drawn_values)
+        cast(Callable[[], np.ndarray], draw_fn)
+        return np.stack([draw_fn() for _ in range(draws)])

     # Multiple variable output
+    cast(Callable[[], List[np.ndarray]], draw_fn)
     drawn_values = zip(*(draw_fn() for _ in range(draws)))
     return [np.stack(v) for v in drawn_values]

@@ -2120,7 +2132,7 @@ def _init_jitter(
     seeds: Sequence[int],
     jitter: bool,
     jitter_max_retries: int,
-) -> PointType:
+) -> List[PointType]:
     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

     ``model.check_start_vals`` is used to test whether the jittered starting
@@ -2144,7 +2156,7 @@ def _init_jitter(
     ipfns = make_initial_point_fns_per_chain(
         model=model,
         overrides=initvals,
-        jitter_rvs=set(model.free_RVs) if jitter else {},
+        jitter_rvs=set(model.free_RVs) if jitter else set(),
         chains=len(seeds),
     )

@@ -2282,6 +2294,7 @@ def init_nuts(

     apoints = [DictToArrayBijection.map(point) for point in initial_points]
     apoints_data = [apoint.data for apoint in apoints]
+    potential: quadpotential.QuadPotential

     if init == "adapt_diag":
         mean = np.mean(apoints_data, axis=0)
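
Several of the hunks above lean on typing.cast, which does nothing at runtime (it simply returns its argument) and exists only to tell the checker something it cannot infer, e.g. that np.argmax yields an integer index and that slicing the argsort result gives a sequence of ints. A self-contained sketch of the _choose_chains-style usage, with made-up lengths:

from typing import Sequence, cast

import numpy as np

lengths = [10, 30, 20]
idxs = np.argsort(lengths)

# As far as the checker is concerned, np.argmax returns a NumPy integer scalar
# and fancy indexing returns a loosely typed array; cast() records what we know
# without changing the values.
use_until = cast(int, np.argmax(np.array(lengths)[idxs]))
take_idx = cast(Sequence[int], idxs[use_until:])
print(use_until, list(take_idx))  # 2 [1]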

pymc/smc/smc.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 import abc

 from abc import ABC
-from typing import Dict
+from typing import Dict, cast

 import aesara.tensor as at
 import numpy as np
@@ -173,15 +173,15 @@ def __init__(
         self.resampling_indexes = None
         self.weights = np.ones(self.draws) / self.draws

-    def initialize_population(self) -> Dict[str, NDArray]:
+    def initialize_population(self) -> Dict[str, np.ndarray]:
         """Create an initial population from the prior distribution"""
-
-        return sample_prior_predictive(
+        result = sample_prior_predictive(
             self.draws,
             var_names=[v.name for v in self.model.unobserved_value_vars],
             model=self.model,
             return_inferencedata=False,
         )
+        return cast(Dict[str, np.ndarray], result)

     def _initialize_kernel(self):
         """Create variables and logp function necessary to run kernel

pymc/step_methods/arraystep.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ class Competence(IntEnum):
 class BlockedStep(ABC):

     generates_stats = False
-    stats_dtypes: List[Dict[str, np.dtype]] = []
+    stats_dtypes: List[Dict[str, type]] = []
     vars: List[Variable] = []

     def __new__(cls, *args, **kwargs):
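
The stats_dtypes entries are typically NumPy scalar classes or built-in types (np.float64, np.int64, bool), and those are instances of type, not of np.dtype, which is what the corrected annotation reflects. A quick check:

import numpy as np

print(isinstance(np.float64, type))               # True: a scalar class
print(isinstance(np.dtype("float64"), type))      # False: a dtype instance
print(isinstance(np.dtype("float64"), np.dtype))  # True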

pymc/tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
@@ -635,7 +635,7 @@ def test_exceptions(self, caplog):

         # test wrong type argument
         bad_trace = {"mu": stats.norm.rvs(size=1000)}
-        with pytest.raises(TypeError):
+        with pytest.raises(TypeError, match="type for `trace`"):
            ppc = pm.sample_posterior_predictive(bad_trace)

     def test_vector_observed(self):
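
pytest.raises(..., match=...) applies re.search to the string form of the raised exception, so matching on the "type for `trace`" fragment ties the test to the new TypeError branch rather than to a TypeError raised for some other reason. A minimal standalone example of the same pattern:

import pytest


def test_match_is_searched_in_the_message():
    # `match` is treated as a regular expression and searched in str(exc).
    with pytest.raises(TypeError, match="type for `trace`"):
        raise TypeError("Unsupported type for `trace` argument: <class 'dict'>.")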

scripts/run_mypy.py

Lines changed: 7 additions & 0 deletions
@@ -50,15 +50,19 @@
 pymc/ode/utils.py
 pymc/parallel_sampling.py
 pymc/plots/__init__.py
+pymc/sampling.py
 pymc/smc/__init__.py
 pymc/smc/sample_smc.py
+pymc/smc/smc.py
 pymc/stats/__init__.py
 pymc/step_methods/__init__.py
 pymc/step_methods/compound.py
 pymc/step_methods/elliptical_slice.py
 pymc/step_methods/hmc/__init__.py
 pymc/step_methods/hmc/base_hmc.py
+pymc/step_methods/hmc/hmc.py
 pymc/step_methods/hmc/integration.py
+pymc/step_methods/hmc/nuts.py
 pymc/step_methods/hmc/quadpotential.py
 pymc/step_methods/slicer.py
 pymc/step_methods/step_sizes.py
@@ -159,6 +163,9 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
        print(f"{len(unexpected_passing)} files unexpectedly passed the type checks:")
        print("\n".join(sorted(map(str, unexpected_passing))))
        print("This is good news! Go to scripts/run-mypy.py and add them to the list.")
+        if all_files.issubset(passing):
+            print("WOW! All files are passing the mypy type checks!")
+            print("scripts\\run_mypy.py may no longer be needed.")
        sys.exit(1)
    return
