Skip to content

Commit 9f93b3e

Browse files
jhrcookricardoV94
andauthored
Typehints and updated docstring for Blackjax NUTS sampling function (#6022)
* refactor: typehints for arguments and return of `sample_blackjax_nuts` * doc: style and add `initvals` to `sample_blackjax_nuts` docstring * refactor: change `var_names` from `Iterable` to `Sequence` typehint Co-authored-by: Ricardo Vieira <[email protected]> * style: remove Iterable import Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 18bbcbb commit 9f93b3e

File tree

1 file changed

+40
-31
lines changed

1 file changed

+40
-31
lines changed

pymc/sampling_jax.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55

66
from functools import partial
7-
from typing import Callable, Dict, List, Optional, Sequence, Union
7+
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
88

99
from pymc.initial_point import StartDict
1010
from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter
@@ -209,60 +209,69 @@ def one_step(state, rng_key):
209209

210210

211211
def sample_blackjax_nuts(
212-
draws=1000,
213-
tune=1000,
214-
chains=4,
215-
target_accept=0.8,
216-
random_seed: RandomSeed = None,
217-
initvals=None,
218-
model=None,
219-
var_names=None,
220-
keep_untransformed=False,
221-
chain_method="parallel",
222-
postprocessing_backend=None,
223-
idata_kwargs=None,
224-
):
212+
draws: int = 1000,
213+
tune: int = 1000,
214+
chains: int = 4,
215+
target_accept: float = 0.8,
216+
random_seed: Optional[RandomSeed] = None,
217+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
218+
model: Optional[Model] = None,
219+
var_names: Optional[Sequence[str]] = None,
220+
keep_untransformed: bool = False,
221+
chain_method: str = "parallel",
222+
postprocessing_backend: Optional[str] = None,
223+
idata_kwargs: Optional[Dict[str, Any]] = None,
224+
) -> az.InferenceData:
225225
"""
226226
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
227227
228228
Parameters
229229
----------
230230
draws : int, default 1000
231-
The number of samples to draw. The number of tuned samples are discarded by default.
231+
The number of samples to draw. The number of tuned samples are discarded by
232+
default.
232233
tune : int, default 1000
233234
Number of iterations to tune. Samplers adjust the step sizes, scalings or
234-
similar during tuning. Tuning samples will be drawn in addition to the number specified in
235-
the ``draws`` argument.
235+
similar during tuning. Tuning samples will be drawn in addition to the number
236+
specified in the ``draws`` argument.
236237
chains : int, default 4
237238
The number of chains to sample.
238239
target_accept : float in [0, 1].
239-
The step size is tuned such that we approximate this acceptance rate. Higher values like
240-
0.9 or 0.95 often work better for problematic posteriors.
240+
The step size is tuned such that we approximate this acceptance rate. Higher
241+
values like 0.9 or 0.95 often work better for problematic posteriors.
241242
random_seed : int, RandomState or Generator, optional
242243
Random seed used by the sampling steps.
244+
initvals: StartDict or Sequence[Optional[StartDict]], optional
245+
Initial values for random variables provided as a dictionary (or sequence of
246+
dictionaries) mapping the random variable (by name or reference) to desired
247+
starting values.
243248
model : Model, optional
244-
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
245-
context, it defaults to that model, otherwise the model must be passed explicitly.
246-
var_names : iterable of str, optional
247-
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
249+
Model to sample from. The model needs to have free random variables. When inside
250+
a ``with`` model context, it defaults to that model, otherwise the model must be
251+
passed explicitly.
252+
var_names : sequence of str, optional
253+
Names of variables for which to compute the posterior samples. Defaults to all
254+
variables in the posterior.
248255
keep_untransformed : bool, default False
249256
Include untransformed variables in the posterior samples. Defaults to False.
250257
chain_method : str, default "parallel"
251-
Specify how samples should be drawn. The choices include "parallel", and "vectorized".
258+
Specify how samples should be drawn. The choices include "parallel", and
259+
"vectorized".
252260
postprocessing_backend : str, optional
253261
Specify how postprocessing should be computed. gpu or cpu
254262
idata_kwargs : dict, optional
255-
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
256-
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
257-
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
258-
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
259-
in ``idata_kwargs``.
263+
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
264+
value for the ``log_likelihood`` key to indicate that the pointwise log
265+
likelihood should not be included in the returned object. Values for
266+
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
267+
the ``model`` argument if not provided in ``idata_kwargs``.
260268
261269
Returns
262270
-------
263271
InferenceData
264-
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
265-
pointwise log likeihood values (unless skipped with ``idata_kwargs``).
272+
ArviZ ``InferenceData`` object that contains the posterior samples, together
273+
with their respective sample stats and pointwise log likeihood values (unless
274+
skipped with ``idata_kwargs``).
266275
"""
267276
import blackjax
268277

0 commit comments

Comments
 (0)