An exercise with both native JaxNS and the Numpyro wrapper #41
Replies: 3 comments 1 reply
Q0: In your code you're not JIT-compiling. Try:

```python
import jax
import jaxns

# log_lik and prior_chain come from your own model code.
ns = jaxns.nested_sampling.NestedSampler(log_lik, prior_chain,
                                         num_live_points=prior_chain.U_ndims * 500)
ns = jax.jit(ns)
```

Q1: Just to be clear, are you asking about

Q2: That's a bug in 0.0.7. I've fixed it, and the fix will be released in 0.0.8.

Q3: You never need to resample for the corner plot; it does that for you. You just feed in

R4: With nested sampling you can't really control the number of samples you get out, since they are weighted. I guess it would be possible to recalculate the ESS after each run and make a stopping criterion based on having enough ESS. That's not too hard. I'll make an issue for it.
Thanks @Joshuaalbert for your kind answers.
@jecampagne, anything unanswered here?
Here is an exercise fitting a model to observations. Here is the snippet in JaxNS (0.0.7). At the end I list some questions:
Q0: I am running on a CPU machine (i.e. not a GPU) and it takes 3 min 30 s to get the results. This is strange, because with the Numpyro wrapper it takes about 30 s to run the NestedSampler and get 5000 samples.
Then
crashes with:

```
TypeError: percentile requires ndarray or scalar arguments, got <class 'list'> at position 1.
```

I get this error with other JaxNS exercises too. I do not know if it's OK or not, but after
I cannot use the corner plots provided by the library, so I have written my own.
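Because nested-sampling output is weighted, summaries for a home-made corner plot can be computed from the weights directly, without resampling. This is a minimal sketch of weighted quantiles (e.g. the 16/50/84 % levels) via the weighted empirical CDF, assuming you have extracted 1-D samples for one parameter and the matching log-weights from the results; `weighted_quantiles` is an illustrative helper, not JaxNS API. Separately, the `TypeError` above is the message `jax.numpy.percentile` gives when its `q` argument (position 1) is a Python list; in your own code, passing `jnp.asarray([16, 50, 84])` instead of a bare list avoids it.

```python
import numpy as np


def weighted_quantiles(samples, log_weights, qs):
    """Quantiles of a 1-D weighted sample via the weighted empirical CDF."""
    samples = np.asarray(samples, dtype=float)
    lw = np.asarray(log_weights, dtype=float)
    w = np.exp(lw - lw.max())  # unnormalised weights, overflow-safe
    order = np.argsort(samples)
    x, w = samples[order], w[order]
    cdf = np.cumsum(w) / np.sum(w)
    # Quantile levels as an ndarray, not a Python list.
    idx = np.searchsorted(cdf, np.asarray(qs, dtype=float), side="left")
    return x[np.clip(idx, 0, len(x) - 1)]


# Example: median and 68 % interval of a weighted sample.
rng = np.random.default_rng(1)
x = rng.normal(size=2000)
lo, med, hi = weighted_quantiles(x, np.zeros_like(x), [0.16, 0.5, 0.84])
```

Returning the first sample whose cumulative weight reaches each level keeps the result an actual sample value; interpolating between neighbours is an equally valid choice.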
So,
gives

which looks fine; at least it matches the Numpyro wrapper results and the NUTS sampling results.
But (see also Q0 above):
Q1) I wonder: how does JaxNS decide to give me 100_000 samples, while sometimes it gives only a few hundred?
Q2) Also, why does the percentile call crash in the JaxNS diagnostics (print/corner...)?
Q3) Why, in the https://github.com/Joshuaalbert/jaxns/blob/master/examples/jones_scalar_model.py example, is there no need to use `jaxns.utils.resample` to get the samples for the corner plot?
R4) By the way, a `get_samples(random key, num_samples)` à la Numpyro wrapper would be useful.
Thanks!
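The `get_samples(random key, num_samples)` request in R4 amounts to resampling the weighted output into a fixed number of equally weighted draws. This is a minimal sketch, assuming you have the samples and their log-weights as arrays; it uses multinomial resampling with replacement and a NumPy `Generator` in place of a JAX PRNG key. The name and signature are hypothetical, not JaxNS API, and this is not necessarily how `jaxns.utils.resample` works.

```python
import numpy as np


def get_samples(rng, samples, log_weights, num_samples):
    """Hypothetical helper: draw num_samples equally weighted samples by
    resampling with replacement, with probability proportional to weight."""
    samples = np.asarray(samples)
    lw = np.asarray(log_weights, dtype=float)
    p = np.exp(lw - lw.max())  # overflow-safe unnormalised weights
    p /= p.sum()
    idx = rng.choice(len(p), size=num_samples, replace=True, p=p)
    return samples[idx]  # indexes the first axis, so multi-D samples work too


rng = np.random.default_rng(0)
draws = get_samples(rng, np.array([0.0, 1.0]), np.log([0.25, 0.75]), 5000)
print(draws.mean())  # should be close to 0.75
```

Because draws are made with replacement, heavily weighted samples can appear many times; asking for far more draws than the effective sample size adds no new information.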