Saving deterministic values using BlackJax #792
ASKabalan
started this conversation in
New features
Replies: 1 comment
I think this is not currently possible with BlackJAX. It would be very nice if we could control the has_aux parameter here:
File ~/micromamba/envs/jax/lib/python3.10/site-packages/blackjax/mcmc/hmc.py:89, in init(position, logdensity_fn)
     88 def init(position: ArrayLikeTree, logdensity_fn: Callable):
---> 89     logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
     90     return HMCState(position, logdensity, logdensity_grad)
That way, the logprob function could return the intermediate values as the second element of its output.
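To illustrate the idea, here is a minimal sketch of what `has_aux=True` does in plain JAX. The toy log-density and the names `logdensity_with_aux` and `x_squared` are made up for this example, and nothing here is BlackJAX's actual behavior: it only shows that `jax.value_and_grad` can differentiate the first output while passing the second one through untouched.

```python
import jax
import jax.numpy as jnp


def logdensity_with_aux(position):
    # Hypothetical toy model: unnormalized standard normal log-density on `x`.
    x = position["x"]
    x_squared = x**2  # intermediate ("deterministic") value we want to keep
    logdensity = -0.5 * jnp.sum(x_squared)
    # With has_aux=True the function must return a (value, aux) pair.
    return logdensity, {"x_squared": x_squared}


position = {"x": jnp.array([1.0, 2.0])}

# Only the first output is differentiated; `aux` is returned as-is.
(logdensity, aux), logdensity_grad = jax.value_and_grad(
    logdensity_with_aux, has_aux=True
)(position)
```

If `init` (and the kernel) forwarded something like this, the auxiliary dictionary could presumably be stored next to the log-density, but that would be a change to the library, not something it is confirmed to support.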
Hello 👋
Thank you very much for this nice package.
I was wondering if there is an equivalent to numpyro.deterministic that would let me log intermediate values from my logprob function.
Thank you
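For context, one possible workaround (a sketch only, not a BlackJAX feature) is to recompute the intermediate values after sampling by mapping a pure function over the stored positions. The function `deterministics` and the `positions` array below are hypothetical stand-ins, and the sketch assumes the sampled positions are stacked along the leading axis.

```python
import jax
import jax.numpy as jnp


def deterministics(position):
    # Hypothetical: the same intermediate computation used inside the logprob.
    x = position["x"]
    return {"x_squared": x**2}


# Stand-in for sampled positions: 3 draws of a 2-dimensional parameter,
# stacked along the leading axis (assumed layout).
positions = {"x": jnp.array([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]])}

# vmap applies `deterministics` to each draw independently.
logged = jax.vmap(deterministics)(positions)
```

This costs one extra evaluation of the intermediate computation per draw, so it is most reasonable when that computation is cheap compared to sampling.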