More accurate experimental.jet
#9712
Unanswered
YouJiacheng asked this question in Ideas
That's a great idea. (Sorry for not noticing this until now.) Want to make a PR to add this? (Also, thanks for all the amazing contributions here recently!)
Since `jit` does constant folding, we can compute the factorial exactly by brute force with `lax.prod(range(1, n + 1))` instead of approximating it with `lax.exp(lax.lgamma(n+1.))`. Thanks to Python's dynamic nature, we don't need to modify the source code of `jax`.

With this modification, the result is:
(DeviceArray(3., dtype=float32), [DeviceArray(6., dtype=float32), DeviceArray(10., dtype=float32)])
Without this modification, the result is:
(DeviceArray(2.9999986, dtype=float32), [DeviceArray(5.999997, dtype=float32), DeviceArray(9.999995, dtype=float32)])
This looks much better, though it probably makes no difference for most applications.
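The runtime patch described in the post can be sketched as follows. This is a minimal sketch under assumptions, not the author's code: it assumes `jax.experimental.jet` defines a module-level helper named `fact(n)`, called with Python ints, that the Taylor rules look up at call time (an internal detail that may differ between JAX versions). The names `exact_fact`, `lgamma_fact`, and `patch_jet_factorial` are hypothetical. It also swaps the post's `lax.prod` for Python's `math.prod`, so the product is plain integer arithmetic done at trace time.

```python
import math


def exact_fact(n: int) -> int:
    """Exact n! by brute-force product; n is assumed to be a Python int."""
    # math.prod of an empty range is 1, so exact_fact(0) == 1.
    return math.prod(range(1, n + 1))


def lgamma_fact(n: int) -> float:
    """The approximation being replaced: n! = exp(lgamma(n + 1)).

    Computed here in float64, so the error is tiny; the errors reported in
    the post appear in float32 results.
    """
    return math.exp(math.lgamma(n + 1.0))


def patch_jet_factorial() -> bool:
    """Rebind jet's factorial helper to the exact one without editing jax.

    Assumption (hypothetical, version-dependent): jax.experimental.jet has a
    module-level callable `fact` that is looked up at call time. Returns
    False if jax is not installed or no such helper exists.
    """
    try:
        from jax.experimental import jet
    except ImportError:
        return False
    if not callable(getattr(jet, "fact", None)):
        return False
    # Rebinding the module global is the "dynamic Python" trick:
    # the jax source files themselves are left untouched.
    jet.fact = exact_fact
    return True


if __name__ == "__main__":
    for n in range(6):
        print(n, exact_fact(n), lgamma_fact(n))
    print("patched:", patch_jet_factorial())
```

Calling `patch_jet_factorial()` before `jet.jet` is intended to mimic the "with this modification" run above; whether it does depends on the jet internals of the installed JAX version. Because `exact_fact` returns a Python int, the sketch also assumes jet only combines the result with arrays through ordinary arithmetic.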