Skip to content

Commit e739546

Browse files
Introduce IBaseTrace with minimal API for chain backends
1 parent 51724c5 commit e739546

File tree

3 files changed

+92
-56
lines changed

3 files changed

+92
-56
lines changed

pymc/backends/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@
6565

6666
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
6767
from pymc.backends.base import BaseTrace
68-
from pymc.backends.ndarray import NDArray, point_list_to_multitrace
68+
from pymc.backends.ndarray import NDArray
69+
from pymc.model import Model
6970

7071
__all__ = ["to_inference_data", "predictions_to_inference_data"]
7172

@@ -76,7 +77,7 @@ def _init_trace(
7677
chain_number: int,
7778
stats_dtypes: List[Dict[str, type]],
7879
trace: Optional[BaseTrace],
79-
model,
80+
model: Model,
8081
) -> BaseTrace:
8182
"""Initializes a trace backend for a chain."""
8283
strace: BaseTrace

pymc/backends/base.py

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222

2323
from abc import ABC
2424
from typing import (
25+
Any,
2526
Dict,
2627
List,
28+
Mapping,
2729
Optional,
2830
Sequence,
2931
Set,
@@ -47,7 +49,87 @@ class BackendError(Exception):
4749
pass
4850

4951

50-
class BaseTrace(ABC):
52+
class IBaseTrace(ABC, Sized):
53+
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""
54+
55+
chain: int
56+
"""Chain number."""
57+
58+
varnames: List[str]
59+
"""Names of tracked variables."""
60+
61+
sampler_vars: List[Dict[str, type]]
62+
"""Sampler stats for each sampler."""
63+
64+
def __len__(self):
65+
raise NotImplementedError()
66+
67+
def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
68+
"""Get values from trace.
69+
70+
Parameters
71+
----------
72+
varname: str
73+
burn: int
74+
thin: int
75+
76+
Returns
77+
-------
78+
A NumPy array
79+
"""
80+
raise NotImplementedError()
81+
82+
def get_sampler_stats(self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1):
83+
"""Get sampler statistics from the trace.
84+
85+
Parameters
86+
----------
87+
stat_name: str
88+
sampler_idx: int or None
89+
burn: int
90+
thin: int
91+
92+
Returns
93+
-------
94+
If the `sampler_idx` is specified, return the statistic with
95+
the given name in a numpy array. If it is not specified and there
96+
is more than one sampler that provides this statistic, return
97+
a numpy array of shape (m, n), where `m` is the number of
98+
such samplers, and `n` is the number of samples.
99+
"""
100+
raise NotImplementedError()
101+
102+
def _slice(self, idx: slice) -> "IBaseTrace":
103+
"""Slice trace object."""
104+
raise NotImplementedError()
105+
106+
def point(self, idx: int) -> Dict[str, np.ndarray]:
107+
"""Return dictionary of point values at `idx` for current chain
108+
with variables names as keys.
109+
"""
110+
raise NotImplementedError()
111+
112+
def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
113+
"""Record results of a sampling iteration.
114+
115+
Parameters
116+
----------
117+
draw: dict
118+
Values mapped to variable names
119+
stats: list of dicts
120+
The diagnostic values for each sampler
121+
"""
122+
raise NotImplementedError()
123+
124+
def close(self):
125+
"""Close the backend.
126+
127+
This is called after sampling has finished.
128+
"""
129+
pass
130+
131+
132+
class BaseTrace(IBaseTrace):
51133
"""Base trace object
52134
53135
Parameters
@@ -127,25 +209,6 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
127209
self._set_sampler_vars(sampler_vars)
128210
self._is_base_setup = True
129211

130-
def record(self, point, sampler_states=None):
131-
"""Record results of a sampling iteration.
132-
133-
Parameters
134-
----------
135-
point: dict
136-
Values mapped to variable names
137-
sampler_states: list of dicts
138-
The diagnostic values for each sampler
139-
"""
140-
raise NotImplementedError
141-
142-
def close(self):
143-
"""Close the database backend.
144-
145-
This is called after sampling has finished.
146-
"""
147-
pass
148-
149212
# Selection methods
150213

151214
def __getitem__(self, idx):
@@ -157,24 +220,6 @@ def __getitem__(self, idx):
157220
except (ValueError, TypeError): # Passed variable or variable name.
158221
raise ValueError("Can only index with slice or integer")
159222

160-
def __len__(self):
161-
raise NotImplementedError
162-
163-
def get_values(self, varname, burn=0, thin=1):
164-
"""Get values from trace.
165-
166-
Parameters
167-
----------
168-
varname: str
169-
burn: int
170-
thin: int
171-
172-
Returns
173-
-------
174-
A NumPy array
175-
"""
176-
raise NotImplementedError
177-
178223
def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
179224
"""Get sampler statistics from the trace.
180225
@@ -220,19 +265,9 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
220265
"""Get sampler statistics."""
221266
raise NotImplementedError()
222267

223-
def _slice(self, idx: Union[int, slice]):
224-
"""Slice trace object."""
225-
raise NotImplementedError()
226-
227-
def point(self, idx: int) -> Dict[str, np.ndarray]:
228-
"""Return dictionary of point values at `idx` for current chain
229-
with variables names as keys.
230-
"""
231-
raise NotImplementedError()
232-
233268
@property
234269
def stat_names(self) -> Set[str]:
235-
names = set()
270+
names: Set[str] = set()
236271
for vars in self.sampler_vars or []:
237272
names.update(vars.keys())
238273

@@ -290,7 +325,7 @@ class MultiTrace:
290325
List of variable names in the trace(s)
291326
"""
292327

293-
def __init__(self, straces: Sequence[BaseTrace]):
328+
def __init__(self, straces: Sequence[IBaseTrace]):
294329
if len({t.chain for t in straces}) != len(straces):
295330
raise ValueError("Chains are not unique.")
296331
self._straces = {t.chain: t for t in straces}
@@ -386,7 +421,7 @@ def stat_names(self) -> Set[str]:
386421
sampler_vars = [s.sampler_vars for s in self._straces.values()]
387422
if not all(svars == sampler_vars[0] for svars in sampler_vars):
388423
raise ValueError("Inividual chains contain different sampler stats")
389-
names = set()
424+
names: Set[str] = set()
390425
for trace in self._straces.values():
391426
if trace.sampler_vars is None:
392427
continue
@@ -472,7 +507,7 @@ def get_sampler_stats(
472507
]
473508
return _squeeze_cat(results, combine, squeeze)
474509

475-
def _slice(self, slice):
510+
def _slice(self, slice: slice):
476511
"""Return a new MultiTrace object sliced according to `slice`."""
477512
new_traces = [trace._slice(slice) for trace in self._straces.values()]
478513
trace = MultiTrace(new_traces)

pymc/backends/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
156156
"""
157157
return self.samples[varname][burn::thin]
158158

159-
def _slice(self, idx):
159+
def _slice(self, idx: slice):
160160
# Slicing directly instead of using _slice_as_ndarray to
161161
# support stop value in slice (which is needed by
162162
# iter_sample).
@@ -174,7 +174,7 @@ def _slice(self, idx):
174174
return sliced
175175
sliced._stats = []
176176
for vars in self._stats:
177-
var_sliced = {}
177+
var_sliced: Dict[str, np.ndarray] = {}
178178
sliced._stats.append(var_sliced)
179179
for key, vals in vars.items():
180180
var_sliced[key] = vals[idx]

0 commit comments

Comments
 (0)