Skip to content

Commit ad3abd9

Browse files
Update docstring
1 parent f9b6258 commit ad3abd9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc_experimental/inference/jax_find_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def fit_laplace(
162162
)
163163

164164
f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
165-
logp,
165+
cast(TensorVariable, logp),
166166
use_grad=True,
167167
use_hess=True,
168168
use_hessp=False,
@@ -376,8 +376,8 @@ def find_MAP(
376376
Seed for the random number generator or a numpy Generator for reproducibility
377377
return_raw: bool | False, optinal
378378
Whether to also return the full output of `scipy.optimize.minimize`
379-
jitter : bool, optional
380-
Whether to add jitter to the initial values. Defaults to False.
379+
jitter_rvs : list of TensorVariables, optional
380+
Variables whose initial values should be jittered. If None, all variables are jittered.
381381
progressbar : bool, optional
382382
Whether to display a progress bar during optimization. Defaults to True.
383383
include_transformed: bool, optional

0 commit comments

Comments
 (0)