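The reason you're running into issues with arange and linspace is that you are constructing arrays with data-dependent shapes: their length depends on the value of your input, not just its shape. In JIT-compiled or vmapped code, all arrays must have static shapes that are known when the function is traced.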
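To make the static-shape requirement concrete, here is a minimal sketch (not part of the original answer). It contrasts an `arange`-based factorial, which fails under `jax.jit` because its length depends on `n`, with a fixed-size version that masks out unused entries. `N_MAX`, `factorial_arange` and `factorial_masked` are hypothetical names, and the masked version assumes you can bound `n` ahead of time.

```python
import jax
import jax.numpy as jnp

# Hypothetical upper bound on n. It is a Python int, so the shape below is static.
# With default int32 arithmetic, 12! is the largest factorial that fits.
N_MAX = 12


def factorial_arange(n):
    # The length of this array depends on the *value* of n, which is unknown
    # while JAX traces the function, so this fails under jit or vmap.
    return jnp.prod(jnp.arange(1, n + 1))


def factorial_masked(n):
    # Always build N_MAX elements: the shape no longer depends on n.
    k = jnp.arange(1, N_MAX + 1)
    # Entries with k > n are replaced by 1 so they leave the product unchanged.
    return jnp.prod(jnp.where(k <= n, k, 1))


try:
    jax.jit(factorial_arange)(5)
except TypeError as err:  # jax.errors.ConcretizationTypeError subclasses TypeError
    print("arange version failed:", type(err).__name__)

print(jax.jit(factorial_masked)(5))               # 120
print(jax.vmap(factorial_masked)(jnp.arange(6)))  # values 1, 1, 2, 6, 24, 120
```

The trade-off is wasted work when `n` is much smaller than `N_MAX`, and int32 overflow beyond 12!, which is one reason the `gammaln` route is usually more convenient.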

If you want a fast, vmappable factorial function, the best approach is probably to use jax.scipy.special.gammaln. It computes the log of the gamma function, and since Γ(n + 1) = n! for non-negative integers n, exponentiating gammaln(n + 1) gives the factorial. For example:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def factorial(n):
  # Gamma(n + 1) == n!, and gammaln returns log(Gamma(x)), so exponentiate it.
  return jnp.exp(gammaln(n + 1))

print(factorial(jnp.arange(10)))
# [1.0000005e+00 1.0000000e+00 1.9999986e+00 6.0000024e+00 2.4000023e+01
#  1.2000006e+02 7.2000177e+02 5.0400068e+03 4.0319996e+04 3.628…
```
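As a follow-up sketch (my own, not from the original answer), the same `gammaln` approach can be wrapped in `jax.jit` and `jax.vmap` directly, since every array has a fixed shape. Two caveats are assumed here: in default float32 the results are only approximate (as the output above shows), and n! overflows float32 for n ≥ 35, so staying in log space with `gammaln(n + 1)` is safer when the factorial feeds into a larger product or ratio. `log_factorial` is a hypothetical helper name.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


@jax.jit
def factorial(n):
    # Gamma(n + 1) == n! for non-negative integers n; the result is a float.
    return jnp.exp(gammaln(n + 1.0))


@jax.jit
def log_factorial(n):
    # log(n!) stays finite long after n! itself overflows float32.
    return gammaln(n + 1.0)


ns = jnp.arange(10)
approx = jax.vmap(factorial)(ns)  # fixed shapes, so vmap and jit both work
exact = jnp.array([math.factorial(int(k)) for k in ns], dtype=jnp.float32)
print(jnp.max(jnp.abs(approx - exact) / exact))  # small float32 rounding error

print(factorial(35.0))      # inf: 35! exceeds the float32 maximum
print(log_factorial(35.0))  # finite, roughly 92.1
```

If you need more precision, enabling 64-bit mode in JAX is an option, but working with `log_factorial` avoids the overflow regardless of dtype.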

Answer selected by Peter-Vincent