Replies: 2 comments
-
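In the current version of JAX, you can check by printing the compiled HLO and looking for FFT operations in it:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.fft.fftn(jnp.fft.ifftn(x))

# Print the optimized HLO that XLA produces for the jitted function.
print(jax.jit(g).lower(jnp.arange(100.0)).compile().as_text())
```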
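To compare the FFT round trip with the `jnp.exp(jnp.log(x))` case from the question, here is a minimal sketch that compiles both and prints their optimized HLO. The helper names `fft_roundtrip`, `exp_log`, and `compiled_hlo` are made up for illustration, and what the HLO contains will depend on your JAX/XLA version and backend, so read the printed text rather than assuming a particular result.

```python
import jax
import jax.numpy as jnp

def fft_roundtrip(x):
    # Inverse FFT followed by forward FFT: mathematically the identity.
    return jnp.fft.fftn(jnp.fft.ifftn(x))

def exp_log(x):
    # log followed by exp: mathematically the identity for x > 0.
    return jnp.exp(jnp.log(x))

def compiled_hlo(fn, x):
    # Hypothetical helper: lower, compile, and return the optimized HLO text.
    return jax.jit(fn).lower(x).compile().as_text()

# Strictly positive input so that log(x) is finite.
x = jnp.arange(1.0, 101.0)

roundtrip_hlo = compiled_hlo(fft_roundtrip, x)
exp_log_hlo = compiled_hlo(exp_log, x)

print("=== fftn(ifftn(x)) ===")
print(roundtrip_hlo)
print("=== exp(log(x)) ===")
print(exp_log_hlo)

# Independent of what the compiler does, both should return x numerically.
roundtrip_ok = bool(jnp.allclose(jax.jit(fft_roundtrip)(x).real, x, rtol=1e-4, atol=1e-3))
exp_log_ok = bool(jnp.allclose(jax.jit(exp_log)(x), x, rtol=1e-4))
print(roundtrip_ok, exp_log_ok)
```

If the FFT pair has been simplified away, the first dump should contain no FFT instruction; if both transforms are still there, the round trip is being computed at run time.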
-
Thank you!
-
My understanding is that the XLA compiler under `jit` turns `jnp.exp(jnp.log(x))` into `x`. Does it do the same for `jnp.fft.fftn(jnp.fft.ifftn(x))`?