|
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | from functools import partial
|
7 |
| -from typing import Callable, Dict, List, Optional, Sequence, Union |
| 7 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Union |
8 | 8 |
|
9 | 9 | from pymc.initial_point import StartDict
|
10 | 10 | from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter
|
@@ -209,60 +209,69 @@ def one_step(state, rng_key):
|
209 | 209 |
|
210 | 210 |
|
211 | 211 | def sample_blackjax_nuts(
|
212 |
| - draws=1000, |
213 |
| - tune=1000, |
214 |
| - chains=4, |
215 |
| - target_accept=0.8, |
216 |
| - random_seed: RandomSeed = None, |
217 |
| - initvals=None, |
218 |
| - model=None, |
219 |
| - var_names=None, |
220 |
| - keep_untransformed=False, |
221 |
| - chain_method="parallel", |
222 |
| - postprocessing_backend=None, |
223 |
| - idata_kwargs=None, |
224 |
| -): |
| 212 | + draws: int = 1000, |
| 213 | + tune: int = 1000, |
| 214 | + chains: int = 4, |
| 215 | + target_accept: float = 0.8, |
| 216 | + random_seed: Optional[RandomSeed] = None, |
| 217 | + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, |
| 218 | + model: Optional[Model] = None, |
| 219 | + var_names: Optional[Sequence[str]] = None, |
| 220 | + keep_untransformed: bool = False, |
| 221 | + chain_method: str = "parallel", |
| 222 | + postprocessing_backend: Optional[str] = None, |
| 223 | + idata_kwargs: Optional[Dict[str, Any]] = None, |
| 224 | +) -> az.InferenceData: |
225 | 225 | """
|
226 | 226 | Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
|
227 | 227 |
|
228 | 228 | Parameters
|
229 | 229 | ----------
|
230 | 230 | draws : int, default 1000
|
231 |
| - The number of samples to draw. The number of tuned samples are discarded by default. |
| 231 | + The number of samples to draw. The number of tuned samples are discarded by |
| 232 | + default. |
232 | 233 | tune : int, default 1000
|
233 | 234 | Number of iterations to tune. Samplers adjust the step sizes, scalings or
|
234 |
| - similar during tuning. Tuning samples will be drawn in addition to the number specified in |
235 |
| - the ``draws`` argument. |
| 235 | + similar during tuning. Tuning samples will be drawn in addition to the number |
| 236 | + specified in the ``draws`` argument. |
236 | 237 | chains : int, default 4
|
237 | 238 | The number of chains to sample.
|
238 | 239 | target_accept : float in [0, 1].
|
239 |
| - The step size is tuned such that we approximate this acceptance rate. Higher values like |
240 |
| - 0.9 or 0.95 often work better for problematic posteriors. |
| 240 | + The step size is tuned such that we approximate this acceptance rate. Higher |
| 241 | + values like 0.9 or 0.95 often work better for problematic posteriors. |
241 | 242 | random_seed : int, RandomState or Generator, optional
|
242 | 243 | Random seed used by the sampling steps.
|
| 244 | + initvals: StartDict or Sequence[Optional[StartDict]], optional |
| 245 | + Initial values for random variables provided as a dictionary (or sequence of |
| 246 | + dictionaries) mapping the random variable (by name or reference) to desired |
| 247 | + starting values. |
243 | 248 | model : Model, optional
|
244 |
| - Model to sample from. The model needs to have free random variables. When inside a ``with`` model |
245 |
| - context, it defaults to that model, otherwise the model must be passed explicitly. |
246 |
| - var_names : iterable of str, optional |
247 |
| - Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior |
| 249 | + Model to sample from. The model needs to have free random variables. When inside |
| 250 | + a ``with`` model context, it defaults to that model, otherwise the model must be |
| 251 | + passed explicitly. |
| 252 | + var_names : sequence of str, optional |
| 253 | + Names of variables for which to compute the posterior samples. Defaults to all |
| 254 | + variables in the posterior. |
248 | 255 | keep_untransformed : bool, default False
|
249 | 256 | Include untransformed variables in the posterior samples. Defaults to False.
|
250 | 257 | chain_method : str, default "parallel"
|
251 |
| - Specify how samples should be drawn. The choices include "parallel", and "vectorized". |
| 258 | + Specify how samples should be drawn. The choices include "parallel", and |
| 259 | + "vectorized". |
252 | 260 | postprocessing_backend : str, optional
|
253 | 261 | Specify how postprocessing should be computed. gpu or cpu
|
254 | 262 | idata_kwargs : dict, optional
|
255 |
| - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value |
256 |
| - for the ``log_likelihood`` key to indicate that the pointwise log likelihood should |
257 |
| - not be included in the returned object. Values for ``observed_data``, ``constant_data``, |
258 |
| - ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided |
259 |
| - in ``idata_kwargs``. |
| 263 | + Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as |
| 264 | + value for the ``log_likelihood`` key to indicate that the pointwise log |
| 265 | + likelihood should not be included in the returned object. Values for |
| 266 | + ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from |
| 267 | + the ``model`` argument if not provided in ``idata_kwargs``. |
260 | 268 |
|
261 | 269 | Returns
|
262 | 270 | -------
|
263 | 271 | InferenceData
|
264 |
| - ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and |
265 |
| - pointwise log likeihood values (unless skipped with ``idata_kwargs``). |
| 272 | + ArviZ ``InferenceData`` object that contains the posterior samples, together |
| 273 | + with their respective sample stats and pointwise log likeihood values (unless |
| 274 | + skipped with ``idata_kwargs``). |
266 | 275 | """
|
267 | 276 | import blackjax
|
268 | 277 |
|
|
0 commit comments