Skip to content

Commit f577c2c

Browse files
maresbaseyboldtmichaelosthege
authored
Improve some type hints to bump mypy pin (#6294)
* Add a few missing type imports * Trade assert with assignment to keep mypy happy * Add a few type annotations * Add missing return type for __call__ * Switch comment type declaration to raw * Get operators.py to pass * Fix pymc.backends.report * Fix a bunch of typing issues * Import __future__.annotations to fix "| None" * Update pymc/step_methods/hmc/integration.py * Add __future__.annotations to hmc.py * Remove unused Any import * Don't cast float to np.array * Replace 0 with 0.0 for float zeros * Update pymc/step_methods/hmc/nuts.py Closes #6282 Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent 5d7283e commit f577c2c

16 files changed

+121
-52
lines changed

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ dependencies:
3838
- watermark
3939
- polyagamma
4040
- sphinx-remove-toctrees
41-
- mypy=0.982
41+
- mypy=0.990
4242
- types-cachetools

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ dependencies:
2727
- pre-commit>=2.8.0
2828
- pytest-cov>=2.5
2929
- pytest>=3.0
30-
- mypy=0.982
30+
- mypy=0.990
3131
- types-cachetools

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@ dependencies:
3535
- sphinx>=1.5
3636
- watermark
3737
- sphinx-remove-toctrees
38-
- mypy=0.982
38+
- mypy=0.990
3939
- types-cachetools

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ dependencies:
2828
- pre-commit>=2.8.0
2929
- pytest-cov>=2.5
3030
- pytest>=3.0
31-
- mypy=0.982
31+
- mypy=0.990
3232
- types-cachetools

pymc/backends/report.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import dataclasses
1616
import logging
1717

18-
from typing import Optional
18+
from typing import Dict, List, Optional
1919

2020
import arviz
2121

@@ -32,7 +32,7 @@
3232
class SamplerReport:
3333
"""Bundle warnings, convergence stats and metadata of a sampling run."""
3434

35-
def __init__(self):
35+
def __init__(self) -> None:
3636
self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
3737
self._global_warnings: List[SamplerWarning] = []
3838
self._n_tune = None

pymc/blocking.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
1818
Classes for working with subsets of parameters.
1919
"""
20-
import collections
20+
from __future__ import annotations
2121

2222
from functools import partial
23-
from typing import Callable, Dict, Generic, Optional, TypeVar
23+
from typing import Callable, Dict, Generic, NamedTuple, TypeVar
2424

2525
import numpy as np
2626

@@ -32,7 +32,9 @@
3232

3333
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
3434
# each of the raveled variables.
35-
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
35+
class RaveledVars(NamedTuple):
36+
data: np.ndarray
37+
point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
3638

3739

3840
class Compose(Generic[T]):
@@ -69,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars:
6971
@staticmethod
7072
def rmap(
7173
array: RaveledVars,
72-
start_point: Optional[PointType] = None,
74+
start_point: PointType | None = None,
7375
) -> PointType:
7476
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
7577
@@ -100,7 +102,7 @@ def rmap(
100102

101103
@classmethod
102104
def mapf(
103-
cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
105+
cls, f: Callable[[PointType], T], start_point: PointType | None = None
104106
) -> Callable[[RaveledVars], T]:
105107
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.
106108

pymc/gp/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
# Avoid circular dependency when importing modelcontext
3232
from pymc.distributions.distribution import Distribution
3333

34-
assert Distribution # keep both pylint and black happy
34+
_ = Distribution # keep both pylint and black happy
3535
from pymc.model import modelcontext
3636

3737
JITTER_DEFAULT = 1e-6

pymc/step_methods/hmc/base_hmc.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import logging
1618
import time
1719

1820
from abc import abstractmethod
19-
from collections import namedtuple
20-
from typing import Optional
21+
from typing import Any, NamedTuple
2122

2223
import numpy as np
2324

@@ -29,20 +30,32 @@
2930
from pymc.step_methods import step_sizes
3031
from pymc.step_methods.arraystep import GradientSharedStep
3132
from pymc.step_methods.hmc import integration
33+
from pymc.step_methods.hmc.integration import IntegrationError, State
3234
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3335
from pymc.tuning import guess_scaling
3436
from pymc.util import get_value_vars_from_user_vars
3537

3638
logger = logging.getLogger("pymc")
3739

38-
HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
3940

40-
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div")
41+
class DivergenceInfo(NamedTuple):
42+
message: str
43+
exec_info: IntegrationError | None
44+
state: State
45+
state_div: State | None
46+
47+
48+
class HMCStepData(NamedTuple):
49+
end: State
50+
accept_stat: int
51+
divergence_info: DivergenceInfo | None
52+
stats: dict[str, Any]
4153

4254

4355
class BaseHMC(GradientSharedStep):
4456
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
4557

58+
integrator: integration.CpuLeapfrogIntegrator
4659
default_blocked = True
4760

4861
def __init__(
@@ -138,13 +151,13 @@ def __init__(
138151
self._num_divs_sample = 0
139152

140153
@abstractmethod
141-
def _hamiltonian_step(self, start, p0, step_size):
154+
def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
142155
"""Compute one Hamiltonian trajectory and return the next state.
143156
144157
Subclasses must overwrite this abstract method and return an `HMCStepData` object.
145158
"""
146159

147-
def astep(self, q0):
160+
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
148161
"""Perform a single HMC iteration."""
149162
perf_start = time.perf_counter()
150163
process_start = time.process_time()
@@ -154,6 +167,7 @@ def astep(self, q0):
154167

155168
start = self.integrator.compute_state(q0, p0)
156169

170+
warning: SamplerWarning | None = None
157171
if not np.isfinite(start.energy):
158172
model = self._model
159173
check_test_point_dict = model.point_logps()
@@ -188,7 +202,6 @@ def astep(self, q0):
188202

189203
self.step_adapt.update(hmc_step.accept_stat, adapt_step)
190204
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
191-
warning: Optional[SamplerWarning] = None
192205
if hmc_step.divergence_info:
193206
info = hmc_step.divergence_info
194207
point = None
@@ -221,7 +234,7 @@ def astep(self, q0):
221234

222235
self.iter_count += 1
223236

224-
stats = {
237+
stats: dict[str, Any] = {
225238
"tune": self.tune,
226239
"diverging": bool(hmc_step.divergence_info),
227240
"perf_counter_diff": perf_end - perf_start,

pymc/step_methods/hmc/hmc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import Any
18+
1519
import numpy as np
1620

1721
from pymc.stats.convergence import SamplerWarning
@@ -119,7 +123,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
119123
self.path_length = path_length
120124
self.max_steps = max_steps
121125

122-
def _hamiltonian_step(self, start, p0, step_size):
126+
def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData:
123127
n_steps = max(1, int(self.path_length / step_size))
124128
n_steps = min(self.max_steps, n_steps)
125129

@@ -156,7 +160,7 @@ def _hamiltonian_step(self, start, p0, step_size):
156160
end = state
157161
accepted = True
158162

159-
stats = {
163+
stats: dict[str, Any] = {
160164
"path_length": self.path_length,
161165
"n_steps": n_steps,
162166
"accept": accept_stat,

pymc/step_methods/hmc/integration.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,32 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections import namedtuple
15+
from typing import NamedTuple
1616

1717
import numpy as np
1818

1919
from scipy import linalg
2020

2121
from pymc.blocking import RaveledVars
22+
from pymc.step_methods.hmc.quadpotential import QuadPotential
2223

23-
State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")
24+
25+
class State(NamedTuple):
26+
q: RaveledVars
27+
p: RaveledVars
28+
v: np.ndarray
29+
q_grad: np.ndarray
30+
energy: float
31+
model_logp: float
32+
index_in_trajectory: int
2433

2534

2635
class IntegrationError(RuntimeError):
2736
pass
2837

2938

3039
class CpuLeapfrogIntegrator:
31-
def __init__(self, potential, logp_dlogp_func):
40+
def __init__(self, potential: QuadPotential, logp_dlogp_func):
3241
"""Leapfrog integrator using CPU."""
3342
self._potential = potential
3443
self._logp_dlogp_func = logp_dlogp_func
@@ -39,14 +48,14 @@ def __init__(self, potential, logp_dlogp_func):
3948
"don't match." % (self._potential.dtype, self._dtype)
4049
)
4150

42-
def compute_state(self, q, p):
51+
def compute_state(self, q: RaveledVars, p: RaveledVars):
4352
"""Compute Hamiltonian functions using a position and momentum."""
4453
if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
4554
raise ValueError("Invalid dtype. Must be %s" % self._dtype)
4655

4756
logp, dlogp = self._logp_dlogp_func(q)
4857

49-
v = self._potential.velocity(p.data)
58+
v = self._potential.velocity(p.data, out=None)
5059
kinetic = self._potential.energy(p.data, velocity=v)
5160
energy = kinetic - logp
5261
return State(q, p, v, dlogp, energy, logp, 0)

0 commit comments

Comments
 (0)