Skip to content

Conversation

@juanitorduz
Copy link
Collaborator

Closes #2088

@juanitorduz juanitorduz requested a review from fehiepsi October 27, 2025 12:22
@juanitorduz juanitorduz added bug Something isn't working enhancement New feature or request labels Oct 27, 2025
return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1))
# Handle edge cases where concentration1=1 and value=0, or concentration0=1 and value=1
# These cases would result in nan due to log(0) * 0 in the Dirichlet computation
log_prob = jnp.where(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting! could you check if grads w.r.t. value, concentration1, concentration0 are not NaN?

Copy link
Collaborator Author

@juanitorduz juanitorduz Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch 🙈 ! The gradients are NaN with this approach. After some investigation and talking with Claude Code 😅 . We arrived to a solution via https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html, see 13fc288 and 19f8be7

@juanitorduz juanitorduz requested a review from fehiepsi November 1, 2025 19:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Beta with concentration1=1 gives nan log_prob at value=0

3 participants