|
| 1 | +""" |
| 2 | +This backend simply discards draws. They are not stored in memory. |
| 3 | +This can be used in situations where we want to run an MCMC but not permanently |
| 4 | +store its output. |
| 5 | +""" |
| 6 | + |
| 7 | +# Code-wise, a NullChain is essentially just a NumpyChain without the underlying data array. |
| 8 | + |
| 9 | +from typing import Dict, List, Mapping, Optional, Sequence, Tuple |
| 10 | + |
| 11 | +import numpy |
| 12 | + |
| 13 | +from ..core import Backend, Chain, Run |
| 14 | +from ..meta import ChainMeta, RunMeta |
| 15 | +from .numpy import grow_append, prepare_storage |
| 16 | + |
| 17 | + |
class NullChain(Chain):
    """A null storage: discards draws immediately and allocates no memory for them.

    Use cases are

    - Online computations: Draws are used and discarded immediately, allowing for much larger sample spaces.
    - Profiling: To use as a baseline, to measure compute time & memory before allocating memory for draws.
      Comparing with another backend would then show how much overhead it adds.

    Since draws are not stored, only a subset of the `Chain` interface is supported:

    - Supported: `__len__`, `append`, `get_stats`, `get_stats_at`
    - Not supported: `get_draws`, `get_draws_at`

    .. Todo:: Option to also discard sampling stats?
    .. Todo:: Allow retrieving the most recent draw?

    """

    def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> None:
        """Creates a null storage for draws from a chain: will gobble outputs without storing them

        Parameters
        ----------
        cmeta : ChainMeta
            Metadata of the chain.
        rmeta : RunMeta
            Metadata of the MCMC run.
        preallocate : int
            Influences the memory pre-allocation behavior.
            (Draws are not saved, but stats may still be.)
            The default is to reserve memory for ``preallocate`` draws
            and grow the allocated memory by 10 % when needed.
            Exceptions are variables with non-rigid shapes (indicated by 0 in the shape tuple)
            where the correct amount of memory cannot be pre-allocated.
            In these cases object arrays are used.
        """
        # Number of `append` calls so far; also the next write position for stats.
        self._draw_idx = 0

        # Create storage ndarrays only for sampler stats.
        self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate)

        super().__init__(cmeta, rmeta)

    def append(  # pylint: disable=duplicate-code
        self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
    ) -> None:
        """Counts one iteration; the draw is dropped, non-empty stats are stored.

        Parameters
        ----------
        draw : Mapping[str, numpy.ndarray]
            Values of the model variables. Ignored by this chain.
        stats : Mapping[str, numpy.ndarray], optional
            Sampler stats of this iteration. ``None`` or an empty mapping stores nothing.
        """
        if stats:
            grow_append(self._stats, stats, self._stat_is_rigid, self._draw_idx)
        self._draw_idx += 1

    def __len__(self) -> int:
        """Returns the number of iterations appended so far."""
        return self._draw_idx

    def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
        """Not supported.

        Raises
        ------
        RuntimeError
            Always, because this chain does not keep draws.
        """
        raise RuntimeError("NullChain does not save draws.")

    def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
        """Not supported.

        Raises
        ------
        RuntimeError
            Always, because this chain does not keep draws.
        """
        raise RuntimeError("NullChain does not save draws.")

    def get_stats(  # pylint: disable=duplicate-code
        self, stat_name: str, slc: slice = slice(None)
    ) -> numpy.ndarray:
        """Returns the stored values of one sampler stat.

        Parameters
        ----------
        stat_name : str
            Name of the sampler stat.
        slc : slice
            Slice applied to the values of the iterations appended so far.

        Returns
        -------
        numpy.ndarray
            The selected values. Stats declared with dtype ``"str"`` are
            converted into a ``str``-typed array.
        """
        # Cut off the pre-allocated tail before applying the caller's slice.
        data = self._stats[stat_name][: self._draw_idx][slc]
        if self.sample_stats[stat_name].dtype == "str":
            return numpy.array(data.tolist(), dtype=str)
        return data

    def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
        """Returns the sampler stats of one iteration.

        Parameters
        ----------
        idx : int
            Iteration index. Negative values count back from the most recently
            appended iteration.
        stat_names : Sequence[str]
            Names of the sampler stats to retrieve.

        Returns
        -------
        Dict[str, numpy.ndarray]
            Maps each requested stat name to its value at ``idx``.

        Raises
        ------
        IndexError
            If ``idx`` is outside of the iterations appended so far.
        """
        # Restrict to the appended range first, like `get_stats` does. Indexing
        # the raw buffer would hand out pre-allocated slots that were never
        # written, and would resolve negative indices against the buffer end.
        n_appended = self._draw_idx
        return {sn: numpy.asarray(self._stats[sn][:n_appended][idx]) for sn in stat_names}
| 89 | + |
| 90 | + |
class NullRun(Run):
    """An MCMC run whose chains throw their draws away right after receiving them."""

    def __init__(self, meta: RunMeta, *, preallocate: int) -> None:
        """Remembers the chain settings and starts with an empty list of chains.

        Parameters
        ----------
        meta : RunMeta
            Metadata of the MCMC run.
        preallocate : int
            Forwarded to every `NullChain` created by this run.
        """
        # Own state is set up before delegating to the base class initializer.
        self._settings = dict(preallocate=preallocate)
        self._chains: List[NullChain] = []
        super().__init__(meta)

    def init_chain(self, chain_number: int) -> NullChain:
        """Creates, registers and returns the `NullChain` for ``chain_number``."""
        new_chain = NullChain(
            ChainMeta(self.meta.rid, chain_number),
            self.meta,
            **self._settings,
        )
        self._chains.append(new_chain)
        return new_chain

    def get_chains(self) -> Tuple[NullChain, ...]:
        """Returns all chains initialized so far, in creation order."""
        return (*self._chains,)
| 107 | + |
| 108 | + |
class NullBackend(Backend):
    """A backend whose runs discard samples as soon as they arrive."""

    def __init__(self, preallocate: int = 1_000) -> None:
        """Stores the settings that are handed to every run.

        Parameters
        ----------
        preallocate : int
            Forwarded to each `NullRun` created by this backend.
        """
        self._settings = dict(preallocate=preallocate)
        super().__init__()

    def init_run(self, meta: RunMeta) -> NullRun:
        """Creates a `NullRun` for ``meta`` using the backend settings."""
        run = NullRun(meta, **self._settings)
        return run
0 commit comments