Skip to content

Commit 76df402

Browse files
michaelosthege authored and
twiecki committed
Add optional McBackend support
1 parent 6a0e74d commit 76df402

File tree

6 files changed

+638
-12
lines changed

6 files changed

+638
-12
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ jobs:
8585
tests/step_methods/hmc/test_quadpotential.py
8686
8787
- |
88+
tests/backends/test_mcbackend.py
8889
tests/distributions/test_truncated.py
8990
tests/logprob/test_abstract.py
9091
tests/logprob/test_censoring.py

pymc/backends/__init__.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,32 @@
6161
6262
"""
6363
from copy import copy
64-
from typing import Dict, List, Optional, Sequence, Union
64+
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union
6565

6666
import numpy as np
6767

68+
from typing_extensions import TypeAlias
69+
6870
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
6971
from pymc.backends.base import BaseTrace, IBaseTrace
7072
from pymc.backends.ndarray import NDArray
7173
from pymc.model import Model
7274
from pymc.step_methods.compound import BlockedStep, CompoundStep
7375

76+
# Flag recording whether the optional McBackend imports below succeeded.
# It starts out False and is only flipped at the very end of the ``try`` body.
HAS_MCB = False
try:
    from mcbackend import Backend, Run

    from pymc.backends.mcbackend import init_chain_adapters

    # Aliases used in the ``init_traces`` signature when McBackend is importable.
    TraceOrBackend = Union[BaseTrace, Backend]
    RunType: TypeAlias = Run
    HAS_MCB = True
except ImportError:
    # Either import failed: fall back to aliases that do not reference McBackend.
    TraceOrBackend = BaseTrace  # type: ignore
    RunType = type(None)  # type: ignore
88+
89+
7490
__all__ = ["to_inference_data", "predictions_to_inference_data"]
7591

7692

@@ -99,16 +115,25 @@ def _init_trace(
99115

100116
def init_traces(
101117
*,
102-
backend: Optional[BaseTrace],
118+
backend: Optional[TraceOrBackend],
103119
chains: int,
104120
expected_length: int,
105121
step: Union[BlockedStep, CompoundStep],
106-
var_dtypes: Dict[str, np.dtype],
107-
var_shapes: Dict[str, Sequence[int]],
122+
initial_point: Mapping[str, np.ndarray],
108123
model: Model,
109-
) -> Sequence[IBaseTrace]:
124+
) -> Tuple[Optional[RunType], Sequence[IBaseTrace]]:
110125
"""Initializes a trace recorder for each chain."""
111-
return [
126+
if HAS_MCB and isinstance(backend, Backend):
127+
return init_chain_adapters(
128+
backend=backend,
129+
chains=chains,
130+
initial_point=initial_point,
131+
step=step,
132+
model=model,
133+
)
134+
135+
assert backend is None or isinstance(backend, BaseTrace)
136+
traces = [
112137
_init_trace(
113138
expected_length=expected_length,
114139
stats_dtypes=step.stats_dtypes,
@@ -118,3 +143,4 @@ def init_traces(
118143
)
119144
for chain_number in range(chains)
120145
]
146+
return None, traces

pymc/backends/mcbackend.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import base64
16+
import logging
17+
import pickle
18+
19+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
20+
21+
import hagelkorn
22+
import mcbackend as mcb
23+
import numpy as np
24+
25+
from mcbackend.npproto.utils import ndarray_from_numpy
26+
from pytensor.compile.sharedvalue import SharedVariable
27+
from pytensor.graph.basic import Constant
28+
29+
from pymc.backends.base import IBaseTrace
30+
from pymc.model import Model
31+
from pymc.pytensorf import PointFunc
32+
from pymc.step_methods.compound import (
33+
BlockedStep,
34+
CompoundStep,
35+
StatsBijection,
36+
flat_statname,
37+
flatten_steps,
38+
)
39+
40+
_log = logging.getLogger("pymc")
41+
42+
43+
def find_data(pmodel: Model) -> List[mcb.DataVariable]:
    """Collect the data containers of a model as McBackend data variables.

    Parameters
    ----------
    pmodel
        The model whose ``named_vars`` are scanned.

    Returns
    -------
    list of mcb.DataVariable
        One entry for every named variable that is a ``Constant`` or a
        ``SharedVariable``; all other named variables are skipped.
    """
    # Value variables belonging to the observed random variables.
    observed_values = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs}
    collected = []
    # All data containers are named vars, so scanning ``named_vars`` suffices.
    for name, var in pmodel.named_vars.items():
        is_constant = isinstance(var, Constant)
        if not (is_constant or isinstance(var, SharedVariable)):
            continue
        raw_value = var.data if is_constant else var.get_value()
        entry = mcb.DataVariable(name)
        entry.value = ndarray_from_numpy(raw_value)
        entry.dims = list(pmodel.named_vars_to_dims.get(name, []))
        entry.is_observed = var in observed_values
        collected.append(entry)
    return collected
60+
61+
62+
def get_variables_and_point_fn(
    model: Model, initial_point: Mapping[str, np.ndarray]
) -> Tuple[List[mcb.Variable], PointFunc]:
    """Get metadata on free, value and deterministic model variables.

    Parameters
    ----------
    model
        The model to describe.
    initial_point
        Point passed once through the compiled point function so that the
        concrete shape of every traced variable can be read off the result.

    Returns
    -------
    variables
        One ``mcb.Variable`` per entry of ``model.unobserved_value_vars``.
    point_fn
        The compiled function mapping a draw to the traced values.
    """
    # The samplers act only on the inputs needed for the log-likelihood,
    # but the user is interested in transformed variables and deterministics.
    value_vars = model.value_vars
    traced_vars = model.unobserved_value_vars
    # Below we compile the "point function" that transforms a draw to the set
    # of untransformed, transformed and deterministic variables that will be traced.
    point_fn = cast(
        PointFunc,
        model.compile_fn(
            traced_vars, inputs=value_vars, on_unused_input="ignore", point_fn=True
        ),
    )
    initial_values = point_fn(initial_point)

    deterministic_names = {d.name for d in model.deterministics}
    variables = [
        mcb.Variable(
            name=var.name,
            dtype=str(var.dtype),
            # Shapes come from the evaluated values, not from the symbolic variables.
            shape=list(value.shape),
            dims=list(model.named_vars_to_dims.get(var.name, [])),
            is_deterministic=var.name in deterministic_names,
        )
        for var, value in zip(traced_vars, initial_values)
    ]
    return variables, point_fn
91+
92+
93+
class ChainRecordAdapter(IBaseTrace):
    """Wraps an McBackend ``Chain`` as an ``IBaseTrace``.

    Draws are passed through the point function before they are appended to
    the chain, and the per-sampler stats are flattened with the given
    ``StatsBijection``. Stats listed in the bijection's ``object_stats`` are
    pickled and base64-encoded on write and decoded again on read.
    """

    def __init__(
        self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection
    ) -> None:
        # Assign attributes required by IBaseTrace
        self.chain = chain.cmeta.chain_number
        self.varnames = [v.name for v in chain.rmeta.variables]
        stats_dtypes = {s.name: np.dtype(s.dtype) for s in chain.rmeta.sample_stats}
        # One {stat name: dtype} dict per sampler, keyed by the sampler-level name.
        self.sampler_vars = [
            {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
            for sstats in stats_bijection._stat_groups
        ]

        self._chain = chain
        self._point_fn = point_fn
        self._statsbj = stats_bijection
        super().__init__()

    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
        """Append one draw and its sampler stats to the wrapped chain."""
        values = self._point_fn(draw)
        value_dict = dict(zip(self.varnames, values))
        stats_dict = self._statsbj.map(stats)
        # Apply pickling to objects stats
        for fname in self._statsbj.object_stats.keys():
            val_bytes = pickle.dumps(stats_dict[fname])
            val = base64.encodebytes(val_bytes).decode("ascii")
            stats_dict[fname] = np.array(val, dtype=str)
        return self._chain.append(value_dict, stats_dict)

    def __len__(self):
        return len(self._chain)

    def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
        """Return the draws of ``varname``, sliced as ``[burn::thin]``."""
        return self._chain.get_draws(varname, slice(burn, None, thin))

    def _get_stats(self, fname: str, slc: slice) -> np.ndarray:
        """Wraps `self._chain.get_stats` but unpickles automatically."""
        values = self._chain.get_stats(fname, slc)
        # Unpickle object stats
        if fname in self._statsbj.object_stats:
            # NOTE(security): ``pickle.loads`` can execute arbitrary code, so
            # object stats must only be read from chains whose storage is trusted.
            objs = [pickle.loads(base64.decodebytes(str(v).encode("ascii"))) for v in values]
            return np.array(objs, dtype=object)
        return values

    def get_sampler_stats(
        self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
    ) -> np.ndarray:
        """Return the recorded values of one sampler stat.

        Parameters
        ----------
        stat_name
            Sampler-level name of the stat (without the sampler prefix).
        sampler_idx
            Index of the sampler to read from. If ``None`` and there is only
            one sampler, that sampler is used; if ``None`` and there are
            several, the stat is collected from every sampler that recorded it.
        burn, thin
            The result is sliced as ``[burn::thin]``.

        Raises
        ------
        KeyError
            If no sampler recorded a stat with exactly this name.
        """
        slc = slice(burn, None, thin)
        # When there's just one sampler, default to remove the sampler dimension
        if sampler_idx is None and self._statsbj.n_samplers == 1:
            sampler_idx = 0
        # Fetching for a specific sampler is easy
        if sampler_idx is not None:
            return self._get_stats(flat_statname(sampler_idx, stat_name), slc)
        # To fetch for all samplers, we must collect the arrays one by one.
        # Match on the exact flat names: a substring test would also select
        # other stats whose name merely contains ``stat_name``.
        wanted = {
            flat_statname(idx, stat_name) for idx in range(self._statsbj.n_samplers)
        }
        stats_dict = {
            stat.name: self._get_stats(stat.name, slc)
            for stat in self._chain.rmeta.sample_stats
            if stat.name in wanted
        }
        if not stats_dict:
            raise KeyError(f"No stat '{stat_name}' was recorded.")
        stats_list = self._statsbj.rmap(stats_dict)
        # Samplers that did not record the stat yield an empty dict and are left out.
        stats_arrays = [next(iter(sd.values())) for sd in stats_list if sd]

        if len(stats_arrays) != len(stats_list):
            _log.debug("Stat '%s' was not recorded by all samplers.", stat_name)
        if len(stats_arrays) == 1:
            return stats_arrays[0]
        return np.array(stats_arrays).T

    def _slice(self, idx: slice) -> "IBaseTrace":
        """Return a new adapter over a ``NumPyChain`` holding the selected draws."""
        # Get the integer indices
        start, stop, step = idx.indices(len(self))
        indices = np.arange(start, stop, step)

        # Create a NumPyChain for the sliced data
        nchain = mcb.backends.numpy.NumPyChain(
            self._chain.cmeta, self._chain.rmeta, preallocate=len(indices)
        )

        # Copy at selected indices and append them to the new chain.
        # This may be slow, but NumPyChain currently don't have a batch-insert or slice API.
        vnames = [v.name for v in nchain.variables.values()]
        snames = [s.name for s in nchain.sample_stats.values()]
        for i in indices:
            draw = self._chain.get_draws_at(i, var_names=vnames)
            stats = self._chain.get_stats_at(i, stat_names=snames)
            nchain.append(draw, stats)
        return ChainRecordAdapter(nchain, self._point_fn, self._statsbj)

    def point(self, idx: int) -> Dict[str, np.ndarray]:
        """Return the draw at position ``idx`` for all variables of the chain."""
        return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()])
200+
201+
202+
def make_runmeta_and_point_fn(
    *,
    initial_point: Mapping[str, np.ndarray],
    step: Union[CompoundStep, BlockedStep],
    model: Model,
) -> Tuple[mcb.RunMeta, PointFunc]:
    """Assemble the McBackend ``RunMeta`` of a run and the matching point function.

    Parameters
    ----------
    initial_point
        Forwarded to ``get_variables_and_point_fn``.
    step
        The step method; it is flattened and the stats of every contained
        sampler are registered under their flat names.
    model
        The model providing variables, coordinates and data.

    Returns
    -------
    meta
        Run metadata with a freshly generated run id.
    point_fn
        The point function returned by ``get_variables_and_point_fn``.
    """
    variables, point_fn = get_variables_and_point_fn(model, initial_point)

    # A sampler-independent "tune" stat always comes first.
    sample_stats = [mcb.Variable("tune", "bool")]

    # In PyMC the sampler stats are grouped by the sampler.
    for sampler_idx, sampler in enumerate(flatten_steps(step)):
        for statname, (dtype, shape) in sampler.stats_dtypes_shapes.items():
            # PyMC uses None to indicate dynamic dims, MCB uses -1
            dim_lengths = [(-1 if dim is None else dim) for dim in (shape or [])]
            sample_stats.append(
                mcb.Variable(
                    name=flat_statname(sampler_idx, statname),
                    dtype=np.dtype(dtype).name,
                    shape=dim_lengths,
                    undefined_ndim=shape is None,
                )
            )

    # Coordinates without values are left out.
    coordinates = []
    for dname, cvals in model.coords.items():
        if cvals is None:
            continue
        coordinates.append(mcb.Coordinate(dname, ndarray_from_numpy(np.array(cvals))))

    meta = mcb.RunMeta(
        rid=hagelkorn.random(),
        variables=variables,
        coordinates=coordinates,
        sample_stats=sample_stats,
        data=find_data(model),
    )
    return meta, point_fn
245+
246+
247+
def init_chain_adapters(
    *,
    backend: mcb.Backend,
    chains: int,
    initial_point: Mapping[str, np.ndarray],
    step: Union[CompoundStep, BlockedStep],
    model: Model,
) -> Tuple[mcb.Run, List[ChainRecordAdapter]]:
    """Start an McBackend run for the MCMC and wrap each of its chains.

    Parameters
    ----------
    backend
        An McBackend `Backend` instance.
    chains
        Number of chains to initialize.
    initial_point
        Dictionary mapping value variable names to initial values.
    step : CompoundStep or BlockedStep
        The step method that iterates the MCMC.
    model : pm.Model
        The current PyMC model.

    Returns
    -------
    run
        The run object returned by ``backend.init_run``.
    adapters
        Chain recording adapters that wrap McBackend Chains in the PyMC IBaseTrace interface.
    """
    run_meta, point_fn = make_runmeta_and_point_fn(
        initial_point=initial_point, step=step, model=model
    )
    run = backend.init_run(run_meta)
    # A single bijection and point function are shared by all chain adapters.
    bijection = StatsBijection(step.stats_dtypes)
    adapters = []
    for chain_number in range(chains):
        chain = run.init_chain(chain_number=chain_number)
        adapters.append(
            ChainRecordAdapter(chain=chain, point_fn=point_fn, stats_bijection=bijection)
        )
    return run, adapters

0 commit comments

Comments
 (0)