Why does rfft transpose use naive FFT instead of irfft? #34553
Replies: 3 comments 15 replies
-
That sounds plausible to me. I suspect we simply didn't realize that when writing the code. PRs welcome!
-
Hello! See my comment above. To get discussion going, here is a minimal example of running the new transpose rule and the corresponding HLO dump. I apologize in advance: I am still learning how to read these dumps, so any help interpreting them is appreciated. In particular, we should make sure that no additional arrays beyond the input are allocated, especially arrays that grow under `vmap`.

```python
import jax
import jax.numpy as jnp

def fn(arr):
    return jnp.fft.rfftn(arr)

arr = jax.random.normal(key=jax.random.key(1234), shape=(10, 10))
primals_out, vjp_fn = jax.vjp(fn, arr)
vjp_fn_jit = jax.jit(vjp_fn)
print(vjp_fn_jit.lower(primals_out).compile().as_text())
```

Please let me know if I am understanding correctly when the new transpose will be invoked.
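Since the concern is allocations that grow under `vmap`, one way to look for them is to dump the HLO of the batched VJP itself; batch-sized temporaries would show up there. A minimal sketch (the batch size of 8, the helper name `apply_vjp`, and the axis choices are my own assumptions, not from the thread):

```python
import jax
import jax.numpy as jnp

def fn(arr):
    return jnp.fft.rfftn(arr)

def apply_vjp(arr, ct):
    # Pull a cotangent back through rfftn; this exercises the transpose rule.
    _, vjp_fn = jax.vjp(fn, arr)
    return vjp_fn(ct)[0]

arrs = jax.random.normal(jax.random.key(0), shape=(8, 10, 10))
cts = jnp.fft.rfftn(arrs, axes=(-2, -1))  # per-example cotangent shape (10, 6)

# HLO text for the vmapped, jitted VJP; scan it for batch-sized buffers.
hlo = jax.jit(jax.vmap(apply_vjp)).lower(arrs, cts).compile().as_text()
print(hlo)
```

The same pattern works for comparing the old and new transpose rules: lower both, diff the dumps, and check whether any allocation scales with the leading batch dimension.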
-
I also tried an implementation as @unalmis suggests:

```python
# Drop-in replacement for _rfft_transpose in jax/_src/lax/fft.py; relies on
# that module's existing imports and helpers (math, lax, fft, FftType, _real_dtype).
def _rfft_transpose(t, fft_lengths):
    if fft_lengths[-1] % 2 == 0:
        t = t.at[..., 1:-1].divide(2.0, indices_are_sorted=True, unique_indices=True)
    else:
        t = t.at[..., 1:].divide(2.0, indices_are_sorted=True, unique_indices=True)
    N = math.prod(fft_lengths)
    out = N * fft(lax.conj(t), FftType.IRFFT, fft_lengths)
    assert out.dtype == _real_dtype(t.dtype), (out.dtype, t.dtype)
    return out
```

Interestingly, this yielded an HLO dump that seems worse. Under `vmap`, there is a constant that gets broadcast to the size of the input array. I find that confusing; does anyone have insight?
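For what it's worth, the mask-then-`irfft` idea can be checked numerically against `jax.linear_transpose` of `rfftn` itself. Below is a standalone sketch of that check, not the `fft.py`-internal code: the helper name `rfft_transpose_via_irfft`, the test shapes, and the tolerances are my own choices.

```python
import math
import jax
import jax.numpy as jnp

def rfft_transpose_via_irfft(t, fft_lengths):
    """Candidate rfftn transpose: halve the Hermitian-doubled bins, then irfftn."""
    if fft_lengths[-1] % 2 == 0:
        t = t.at[..., 1:-1].divide(2.0)  # interior bins only (DC and Nyquist excluded)
    else:
        t = t.at[..., 1:].divide(2.0)    # for odd n, every bin except DC is interior
    n = math.prod(fft_lengths)
    return n * jnp.fft.irfftn(jnp.conj(t), s=fft_lengths)

results = []
for shape in [(5, 6), (5, 7)]:  # cover even and odd last-axis lengths
    x = jax.random.normal(jax.random.key(0), shape)
    out_shape = jnp.fft.rfftn(x).shape
    kr, ki = jax.random.split(jax.random.key(1))
    # A generic complex cotangent with the rfftn output shape.
    t = jax.random.normal(kr, out_shape) + 1j * jax.random.normal(ki, out_shape)
    reference = jax.linear_transpose(lambda a: jnp.fft.rfftn(a, s=shape), x)(t)[0]
    candidate = rfft_transpose_via_irfft(t, shape)
    results.append(bool(jnp.allclose(reference, candidate, rtol=1e-4, atol=1e-4)))
print(results)
```

This relies on `irfft` ignoring the imaginary parts of the DC and Nyquist bins (standard complex-to-real FFT behavior), which is exactly what makes the mask-based transpose line up with the real part taken by the naive `fft`-plus-slice transpose.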
-
I noticed that `_rfft_transpose` in `jax/_src/lax/fft.py` computes the transpose by:

- implementing a naive `rfft` as a full `fft(x)` followed by a slice
- applying `linear_transpose` to that naive implementation

This effectively runs a full complex FFT. However, the transpose can be computed directly using `irfft` with a mask. Since `irfft` exploits Hermitian symmetry, it's roughly 2x faster and uses half the memory of a full FFT.

The comment mentions avoiding "manually building up larger twiddle matrices," but the mask-based approach doesn't require twiddle matrices, just element-wise division by `[1, 2, ..., 2, 1]`.

Is there a correctness concern I'm missing, or would a PR switching to the `irfft`-based transpose be welcome?