It meant the output of the logdensity function:

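import jax.numpy as jnp
import jax.scipy.stats  # also binds the top-level `jax` name

def normal_post(theta, xobs):
    # Sum of Normal(theta, 1) log-densities over the observations
    return jnp.sum(jax.scipy.stats.norm.logpdf(xobs, loc=theta))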

It returns float32 initially, and float64 after one step.
Did you also try casting xobs to np.float64?
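
To check where the dtype changes, something like the sketch below might help. It assumes 64-bit mode is enabled (float64 outputs suggest it is, since JAX otherwise keeps everything in 32-bit), and the observations and theta values are made up for illustration. It only demonstrates JAX's type promotion for this function; it does not show which part of your sampler produces the float64 value.

```python
import jax

# Assumption: x64 mode is on. Without it JAX would not produce float64 at all.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_post(theta, xobs):
    # Same log-density as above: sum of Normal(theta, 1) log-pdfs.
    return jnp.sum(norm.logpdf(xobs, loc=theta))


# Hypothetical data, only for illustrating the dtypes.
xobs32 = jnp.array([0.1, -0.4, 1.2], dtype=jnp.float32)
xobs64 = xobs32.astype(jnp.float64)

theta32 = jnp.float32(0.0)
theta64 = jnp.float64(0.0)

out_32_32 = normal_post(theta32, xobs32)
out_64_32 = normal_post(theta64, xobs32)
out_64_64 = normal_post(theta64, xobs64)

print(out_32_32.dtype)  # float32: both inputs are float32
print(out_64_32.dtype)  # float64: a float64 theta promotes the result
print(out_64_64.dtype)  # float64: consistent dtypes throughout
```

If theta becomes float64 after the first step while xobs stays float32, the second case would explain the change. Casting both theta and xobs to the same dtype up front should keep the output dtype stable across steps.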

Answer selected by yyang97