JAX doesn't provide a built-in function for the entropy of a normal distribution, but fortunately the closed-form (differential) entropy, H = 0.5 * (log(2 * pi * sigma^2) + 1) nats, is straightforward to define yourself:

import jax.numpy as jnp

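def normal_entropy(sigma):
  # Differential entropy (in nats) of N(mu, sigma**2). sigma is the standard
  # deviation (not the variance); the result does not depend on the mean mu.
  return 0.5 * (jnp.log(2 * jnp.pi * sigma ** 2) + 1)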
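If you want to convince yourself the formula is right, one option is to compare it against a Monte Carlo estimate of H = -E[log p(x)], drawing samples with `jax.random.normal` and scoring them with `jax.scipy.stats.norm.logpdf`. The helper `mc_normal_entropy`, the sample count, and the diagonal-Gaussian line below are illustrative choices of mine, not part of any JAX API. Because the function is elementwise, it also works on arrays of sigmas; for a Gaussian with diagonal covariance the per-dimension entropies simply add up.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def normal_entropy(sigma):
  # Closed-form differential entropy (nats) of N(mu, sigma**2); independent of mu.
  return 0.5 * (jnp.log(2 * jnp.pi * sigma ** 2) + 1)

def mc_normal_entropy(key, mu, sigma, n=200_000):
  # Hypothetical helper: Monte Carlo estimate of -E[log p(x)] with x ~ N(mu, sigma**2).
  x = mu + sigma * jax.random.normal(key, (n,))
  return -jnp.mean(norm.logpdf(x, loc=mu, scale=sigma))

key = jax.random.PRNGKey(0)
exact = normal_entropy(2.0)
approx = mc_normal_entropy(key, mu=1.0, sigma=2.0)
print(exact, approx)  # the two values should agree to about two decimal places

# Diagonal Gaussian: independent dimensions, so entropies sum over the last axis.
sigmas = jnp.array([1.0, 1.0])
diag_entropy = jnp.sum(normal_entropy(sigmas), axis=-1)
```

The Monte Carlo estimate is only a sanity check; in practice the closed form is exact, cheaper, and differentiable with respect to sigma.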
