Skip to content

Commit b33fe73

Browse files
Add mypy config and type hints
Some initial values were changed to tighten type information.
1 parent f4cfaa1 commit b33fe73

File tree

8 files changed, with 89 additions and 40 deletions.

mcbackend/adapters/pymc.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
44
The only PyMC dependency is on the ``BaseTrace`` abstract base class.
55
"""
6-
from typing import Dict, List, Sequence, Tuple
6+
from typing import Dict, List, Optional, Sequence, Tuple
77

88
import hagelkorn
99
import numpy
@@ -104,22 +104,27 @@ def __init__( # pylint: disable=W0622
104104
vars=None,
105105
test_point=None,
106106
):
107-
self.chain = None
107+
self.chain: int = -1
108108
super().__init__(name, model, vars, test_point)
109109
self.run_id = hagelkorn.random(digits=6)
110110
print(f"Backend run id: {self.run_id}")
111111
self._backend: Backend = backend
112112

113113
# Sessions created from the underlying backend
114-
self._run: Run = None
115-
self._chain: Chain = None
116-
self._stat_groups: List[List[Tuple[str, str]]] = None
114+
self._run: Optional[Run] = None
115+
self._chain: Optional[Chain] = None
116+
self._stat_groups: List[List[Tuple[str, str]]] = []
117117
self._length: int = 0
118118

119119
def __len__(self) -> int:
120120
return self._length
121121

122-
def setup(self, draws, chain, sampler_vars=None) -> None:
122+
def setup(
123+
self,
124+
draws: int,
125+
chain: int,
126+
sampler_vars: Optional[List[Dict[str, numpy.dtype]]] = None,
127+
) -> None:
123128
super().setup(draws, chain, sampler_vars)
124129
self.chain = chain
125130

@@ -131,7 +136,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
131136
name,
132137
str(self.var_dtypes[name]),
133138
list(self.var_shapes[name]),
134-
dims=list(self.model.RV_dims[name]) if name in self.model.RV_dims else None,
139+
dims=list(self.model.RV_dims[name]) if name in self.model.RV_dims else [],
135140
is_deterministic=(name not in free_rv_names),
136141
)
137142
for name in self.varnames
@@ -196,15 +201,23 @@ def record(self, point, sampler_states=None):
196201
return
197202

198203
def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
204+
if self._chain is None:
205+
raise Exception("Trace setup was not completed. Call `.setup()` first.")
199206
return self._chain.get_draws(varname)[burn::thin]
200207

201208
def _get_stats(self, varname, burn=0, thin=1) -> numpy.ndarray:
209+
if self._chain is None:
210+
raise Exception("Trace setup was not completed. Call `.setup()` first.")
202211
return self._chain.get_stats(varname)[burn::thin]
203212

204213
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
214+
if self._chain is None:
215+
raise Exception("Trace setup was not completed. Call `.setup()` first.")
205216
return self._get_stats(f"sampler_{sampler_idx}__{stat_name}", burn, thin)
206217

207218
def point(self, idx: int):
219+
if self._chain is None:
220+
raise Exception("Trace setup was not completed. Call `.setup()` first.")
208221
return self._chain.get_draws_at(idx, self.var_names)
209222

210223
def as_readonly(self) -> ReadOnlyTrace:

mcbackend/backends/clickhouse.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
from datetime import datetime, timezone
8-
from typing import Callable, Dict, Optional, Sequence, Tuple
8+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
99

1010
import clickhouse_driver
1111
import numpy
@@ -80,13 +80,13 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta:
8080
)
8181
columns.append(column_spec_for(var, is_stat=True))
8282
assert len(set(columns)) == len(columns), columns
83-
columns = ",\n ".join(columns)
83+
cols = ",\n ".join(columns)
8484

8585
query = f"""
8686
CREATE TABLE {cid}
8787
(
8888
`_draw_idx` UInt64,
89-
{columns}
89+
{cols}
9090
)
9191
ENGINE TinyLog();
9292
"""
@@ -110,8 +110,8 @@ def __init__(
110110
self._client = client
111111
# The following attributes belong to the batched insert mechanism.
112112
# Inserting in batches is much faster than inserting single rows.
113-
self._insert_query = None
114-
self._insert_queue = []
113+
self._insert_query: str = ""
114+
self._insert_queue: List[Dict[str, Any]] = []
115115
self._last_insert = time.time()
116116
self._insert_interval = insert_interval
117117
self._insert_every = insert_every
@@ -121,7 +121,7 @@ def append(
121121
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
122122
):
123123
stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
124-
params = {"_draw_idx": self._draw_idx, **draw, **stat}
124+
params: Dict[str, Any] = {"_draw_idx": self._draw_idx, **draw, **stat}
125125
self._draw_idx += 1
126126
if not self._insert_query:
127127
names = ", ".join(params.keys())
@@ -186,9 +186,10 @@ def _get_rows( # pylint: disable=W0221
186186

187187
# The unpacking must also account for non-rigid shapes
188188
if is_rigid(nshape):
189+
assert nshape is not None
189190
buffer = numpy.empty((draws, *nshape), dtype)
190191
else:
191-
buffer = numpy.repeat(None, draws)
192+
buffer = numpy.array([None] * draws)
192193
for d, (vals,) in enumerate(data):
193194
buffer[d] = numpy.asarray(vals, dtype)
194195
return buffer
@@ -228,23 +229,21 @@ def __init__(
228229
self.created_at = created_at
229230
# We need handles on the chains to commit their batched inserts
230231
# before returning them to callers of `.get_chains()`.
231-
self._chains = None
232+
self._chains: List[ClickHouseChain] = []
232233
super().__init__(meta)
233234

234235
def init_chain(self, chain_number: int) -> ClickHouseChain:
235236
cmeta = ChainMeta(self.meta.rid, chain_number)
236237
create_chain_table(self._client, cmeta, self.meta)
237238
chain = ClickHouseChain(cmeta, self.meta, client=self._client_fn())
238-
if self._chains is None:
239-
self._chains = []
240239
self._chains.append(chain)
241240
return chain
242241

243-
def get_chains(self) -> Tuple[ClickHouseChain]:
242+
def get_chains(self) -> Tuple[ClickHouseChain, ...]:
244243
# Preferably return existing handles on chains that might have
245244
# uncommitted inserts pending.
246245
if self._chains:
247-
return self._chains
246+
return tuple(self._chains)
248247

249248
# Otherwise fetch existing chains from the DB.
250249
chains = []
@@ -274,13 +273,14 @@ def __init__(
274273
"""
275274
if client is None and client_fn is None:
276275
raise ValueError("Either a `client` or a `client_fn` must be provided.")
277-
self._client_fn = client_fn
278-
self._client = client
279276

280277
if client_fn is None:
281-
self._client_fn = lambda: client
278+
client_fn = lambda: client
282279
if client is None:
283-
self._client = self._client_fn()
280+
client = client_fn()
281+
282+
self._client_fn: Callable[[], clickhouse_driver.Client] = client_fn
283+
self._client: clickhouse_driver.Client = client
284284

285285
create_runs_table(self._client)
286286
super().__init__()

mcbackend/backends/numpy.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This backend holds draws in memory, managing them via NumPy arrays.
33
"""
44
import math
5-
from typing import Dict, Optional, Sequence, Tuple
5+
from typing import Dict, List, Optional, Sequence, Tuple
66

77
import numpy
88

@@ -26,7 +26,7 @@ def grow_append(
2626
if rigid[vn]:
2727
extension = numpy.empty((ngrow,) + numpy.shape(v))
2828
else:
29-
extension = numpy.repeat(None, ngrow)
29+
extension = numpy.array([None] * ngrow)
3030
storage_dict[vn] = numpy.concatenate((target, extension), axis=0)
3131
target = storage_dict[vn]
3232
target[draw_idx] = v
@@ -53,10 +53,10 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
5353
where the correct amount of memory cannot be pre-allocated.
5454
In these cases, and when ``preallocate == 0`` object arrays are used.
5555
"""
56-
self._var_is_rigid = {}
57-
self._samples = {}
58-
self._stat_is_rigid = {}
59-
self._stats = {}
56+
self._var_is_rigid: Dict[str, bool] = {}
57+
self._samples: Dict[str, numpy.ndarray] = {}
58+
self._stat_is_rigid: Dict[str, bool] = {}
59+
self._stats: Dict[str, numpy.ndarray] = {}
6060
self._draw_idx = 0
6161

6262
# Create storage ndarrays for each model variable and sampler stat.
@@ -71,7 +71,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
7171
reserve = (preallocate, *var.shape)
7272
target_dict[var.name] = numpy.empty(reserve, var.dtype)
7373
else:
74-
target_dict[var.name] = numpy.repeat(None, preallocate)
74+
target_dict[var.name] = numpy.array([None] * preallocate)
7575

7676
super().__init__(cmeta, rmeta)
7777

@@ -105,7 +105,7 @@ class NumPyRun(Run):
105105

106106
def __init__(self, meta: RunMeta, *, preallocate: int) -> None:
107107
self._settings = dict(preallocate=preallocate)
108-
self._chains = []
108+
self._chains: List[NumPyChain] = []
109109
super().__init__(meta)
110110

111111
def init_chain(self, chain_number: int) -> NumPyChain:
@@ -114,7 +114,7 @@ def init_chain(self, chain_number: int) -> NumPyChain:
114114
self._chains.append(chain)
115115
return chain
116116

117-
def get_chains(self) -> Tuple[Chain]:
117+
def get_chains(self) -> Tuple[NumPyChain, ...]:
118118
return tuple(self._chains)
119119

120120

mcbackend/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,22 @@
33
"""
44
import collections
55
import logging
6-
from typing import Dict, Optional, Sequence, Sized, TypeVar
6+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Sized, TypeVar
77

88
import numpy
99

1010
from .meta import ChainMeta, RunMeta, Variable
1111
from .npproto.utils import ndarray_to_numpy
1212

13+
InferenceData = TypeVar("InferenceData")
1314
try:
14-
from arviz import InferenceData, from_dict
15+
from arviz import from_dict
16+
17+
if not TYPE_CHECKING:
18+
from arviz import InferenceData
19+
_HAS_ARVIZ = True
1520
except ModuleNotFoundError:
16-
InferenceData = TypeVar("InferenceData")
21+
_HAS_ARVIZ = False
1722

1823
Shape = Sequence[int]
1924
_log = logging.getLogger(__file__)
@@ -31,7 +36,9 @@ def is_rigid(nshape: Optional[Shape]):
3136
- ``[2, 0]`` indicates a matrix with 2 rows and dynamic number of columns (rigid: False).
3237
- ``None`` indicates dynamic dimensionality (rigid: False).
3338
"""
34-
if nshape is None or any(s == 0 for s in nshape):
39+
if nshape is None:
40+
return False
41+
if any(s == 0 for s in nshape):
3542
return False
3643
return True
3744

@@ -127,14 +134,14 @@ def coords(self) -> Dict[str, numpy.ndarray]:
127134
return {coord.name: ndarray_to_numpy(coord.values) for coord in self.meta.coordinates}
128135

129136
@property
130-
def dims(self) -> Dict[str, Sequence[str]]:
137+
def dims(self) -> Dict[str, List[str]]:
131138
dims = {}
132139
for var in self.meta.variables:
133140
if len(var.dims) == len(var.shape) and not var.undefined_ndim:
134-
dims[var.name] = var.dims
141+
dims[var.name] = list(var.dims)
135142
for dvar in self.meta.data:
136143
if len(dvar.dims) > 0:
137-
dims[dvar.name] = dvar.dims
144+
dims[dvar.name] = list(dvar.dims)
138145
return dims
139146

140147
@property
@@ -158,7 +165,7 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
158165
idata : arviz.InferenceData
159166
Samples and metadata of this inference run.
160167
"""
161-
if isinstance(InferenceData, TypeVar):
168+
if not _HAS_ARVIZ:
162169
raise ModuleNotFoundError("ArviZ is not installed.")
163170

164171
variables = self.meta.variables

mcbackend/npproto/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def ndarray_from_numpy(arr: numpy.ndarray) -> Ndarray:
2222

2323

2424
def ndarray_to_numpy(nda: Ndarray) -> numpy.ndarray:
25+
arr: numpy.ndarray
2526
if "datetime64" in nda.dtype:
2627
# Backwards conversion: The data was stored as int64.
2728
arr = numpy.ndarray(

mcbackend/test_adapter_pymc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,21 @@ def wrapper(meta: RunMeta):
139139
ndarray_to_numpy(obs.value), simple_model["obs"].get_value()
140140
)
141141
pass
142+
143+
def test_uninitialized_exceptions(self, simple_model):
144+
backend = ClickHouseBackend(self._client)
145+
with simple_model:
146+
trace = TraceBackend(backend)
147+
148+
with pytest.raises(Exception, match="setup was not completed"):
149+
trace.get_values("scalar")
150+
151+
with pytest.raises(Exception, match="setup was not completed"):
152+
trace._get_stats("doesntmatter")
153+
154+
with pytest.raises(Exception, match="setup was not completed"):
155+
trace._get_sampler_stats("doesntmatter", 0, 0, 1)
156+
157+
with pytest.raises(Exception, match="setup was not completed"):
158+
trace.point(0)
159+
pass

mcbackend/test_npproto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class TestUtils:
1616
numpy.array(5),
1717
numpy.array(["hello", "world"]),
1818
numpy.array([datetime(2020, 3, 4, 5, 6, 7, 8), datetime(2020, 3, 4, 5, 6, 7, 9)]),
19+
numpy.array(
20+
[datetime(2020, 3, 4, 5, 6, 7, 8), datetime(2020, 3, 4, 5, 6, 7, 9)],
21+
dtype="datetime64",
22+
),
1923
numpy.array([(1, 2), (3, 2, 1)], dtype=object),
2024
],
2125
)

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,9 @@ xfail_strict=true
33

44
[tool.black]
55
line-length = 100
6+
7+
[tool.mypy]
8+
exclude = [
9+
"^mcbackend/test_*",
10+
]
11+
ignore_missing_imports = true

Comments (0): this commit has no comments.