What's the best way to map across a factorial function? #8360
-
I want to construct a factorial function. Doing it recursively is the standard method, but that doesn't play well when I try to map across it, since the value is needed for control flow (see example code below).
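The control-flow problem can be reproduced with a minimal sketch. This is a hypothetical recursive definition, not the asker's original code (which is not shown). It assumes the failure comes from a Python `if` on the input. Under `jax.vmap` the argument is an abstract tracer with no concrete value, so the `if` cannot decide which branch to take:

```python
import jax
import jax.numpy as jnp


def factorial_recursive(n):
    # Python-level control flow: `if` needs a concrete boolean value.
    if n <= 1:
        return 1
    return n * factorial_recursive(n - 1)


# Works on an ordinary Python int.
print(factorial_recursive(5))

# Under vmap, `n` is an abstract tracer, so evaluating `n <= 1`
# inside `if` raises a concretization error instead of branching.
vmap_failed = False
try:
    jax.vmap(factorial_recursive)(jnp.arange(6))
except jax.errors.ConcretizationTypeError as err:
    vmap_failed = True
    print("vmap failed:", type(err).__name__)
```

The same function runs fine on concrete inputs. Only tracing under `vmap` or `jit` exposes the problem.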
Instead, I tried the following code:
But this still yields a very similar error:
I've only included some of the elements of the DeviceArray here; then, of course, there are the internal frames. I guess from this … If I instead use … Any ideas on how to get around this?
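A fixed-shape variant shows what "static shapes" means in practice. This is a minimal sketch. It assumes the failing attempt built the range of factors with something like `jnp.arange(1, n + 1)`, whose length depends on the traced `n`. The sketch replaces that with a range of a hypothetical static length `MAX_N` and masks out the unused factors:

```python
import jax
import jax.numpy as jnp

# Hypothetical static upper bound. 12! is the largest factorial that
# fits in int32, JAX's default integer type.
MAX_N = 12


def factorial_masked(n):
    # The shape of `k` depends only on the Python constant MAX_N,
    # never on the traced value `n`, so it is static under vmap/jit.
    k = jnp.arange(1, MAX_N + 1)
    # Factors greater than n become 1 and do not change the product.
    # Note: any n > MAX_N silently returns MAX_N!.
    factors = jnp.where(k <= n, k, 1)
    return jnp.prod(factors)


ns = jnp.arange(6)
result = jax.jit(jax.vmap(factorial_masked))(ns)
print(result)
```

The trade-off is the hard-coded bound. The computation always multiplies `MAX_N` factors, and inputs above the bound are not detected.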
-
The following code works:
But I'm still wondering what's going on under the hood in … Thanks!
-
The reason you're running into issues with `arange` and `linspace` is that you are constructing arrays with data-dependent shapes. In JIT-compiled or vmapped code, all arrays must have static shapes. If you want a fast, vmappable factorial function, the best approach is probably to use `jax.scipy.special.gammaln`: this computes the log of the gamma function, which is related to the factorial function. For example:
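The answer's original example was lost in extraction. The following is a minimal sketch of the approach it describes, relying on the identity Γ(n + 1) = n! for non-negative integers n. The wrapper name `factorial` is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def factorial(n):
    # Gamma(n + 1) == n! for non-negative integers, and gammaln returns
    # log|Gamma|, so exponentiating recovers n! as a float.
    # Adding 1.0 also promotes integer input to floating point.
    return jnp.exp(gammaln(n + 1.0))


ns = jnp.arange(6)
result = jax.vmap(factorial)(ns)
print(result)

# No data-dependent shapes or Python branching, so jit works too.
print(jax.jit(jax.vmap(factorial))(ns))
```

Because `gammaln` is elementwise, `factorial(ns)` also works directly without `vmap`. The result is a float with small rounding error, so apply `jnp.round` before casting if exact integers are needed. In float32, `n!` overflows to `inf` for n ≥ 35. For large n it is often better to stay in log space and work with `gammaln(n + 1.0)` directly.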