Skip to content

Commit afe088f

Browse files
committed
Simplify definition of jax.scipy.special.kl_div
1 parent 0739d52 commit afe088f

File tree

1 file changed

+1
-18
lines changed

1 file changed

+1
-18
lines changed

jax/_src/scipy/special.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -616,24 +616,7 @@ def kl_div(
616616
- :func:`jax.scipy.special.rel_entr`
617617
"""
618618
p, q = promote_args_inexact("kl_div", p, q)
619-
zero = _lax_const(p, 0.0)
620-
both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero))
621-
one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero))
622-
623-
safe_p = jnp.where(both_gt_zero_mask, p, 1)
624-
safe_q = jnp.where(both_gt_zero_mask, q, 1)
625-
626-
log_val = lax.sub(
627-
lax.add(
628-
lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)),
629-
safe_q,
630-
),
631-
safe_p,
632-
)
633-
result = jnp.where(
634-
both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, np.inf)
635-
)
636-
return result
619+
return rel_entr(p, q) - p + q
637620

638621

639622
def rel_entr(

0 commit comments

Comments
 (0)