Hello guys,
I am running a model that functions similarly to a State Space Model (SSM) or Mamba, but in the frequency domain.
My bottleneck is a Causal Spectral Decomposition step.
Currently, I implement this in PyTorch using:
Output = IFFT( FFT(Input, n=2L) * FFT(Kernel, n=2L) ) (truncated to length L).
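For reference, here is a minimal sketch of that formula, checked against a direct O(L^2) causal convolution. It uses NumPy in float64 rather than PyTorch in BF16 so the check stays dependency-light (`torch.fft.rfft` and `torch.fft.irfft` take the same `n` argument). The function names and the small sizes are illustrative assumptions, not the actual model code.

```python
import numpy as np

def causal_fft_conv(x, k):
    """Causal convolution along the last axis via zero-padded real FFTs.

    x: real array [..., L]; k: real array broadcastable to x, last axis L.
    """
    L = x.shape[-1]
    n = 2 * L  # pad to 2L so circular wrap-around cannot reach the first L outputs
    X = np.fft.rfft(x, n=n, axis=-1)
    K = np.fft.rfft(k, n=n, axis=-1)
    y = np.fft.irfft(X * K, n=n, axis=-1)
    return y[..., :L]  # keep only the first L (causal) samples

def causal_direct_conv(x, k):
    """O(L^2) reference: y[t] = sum over s <= t of k[s] * x[t - s]."""
    L = x.shape[-1]
    y = np.zeros_like(x)
    for t in range(L):
        for s in range(t + 1):
            y[..., t] += k[..., s] * x[..., t - s]
    return y

rng = np.random.default_rng(0)
batch, groups, length = 2, 4, 64  # small illustrative sizes, not the 4096-8192 target
x = rng.standard_normal((batch, groups, length))
k = rng.standard_normal((groups, length))  # one real kernel per group (assumption)

y_fft = causal_fft_conv(x, k)
y_ref = causal_direct_conv(x, k)
max_err = np.max(np.abs(y_fft - y_ref))
print(f"max abs difference: {max_err:.2e}")
```

A fused kernel would need to reproduce `y_fft` without writing `X`, `K`, or their product back to global memory.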
I need a 'FlashFFTConv' kernel with these specs:
- Input: [Batch, Groups, Length] (Real BF16)
- Dynamics: 129 Complex Frequencies per Group (Learned Decay + Phase).
- Sequence Length: Optimization target is 4096 to 8192.
- Total Complex Channels: ~516 (4 Groups * 129 Freqs).
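To make the shapes concrete, here is one possible reading of these specs, sketched in NumPy. It assumes each of the 129 frequencies in a group is a damped complex exponential `exp((-decay + i*phase) * t)` that filters the group's real input into its own complex channel. The parametrization, the names `decay`, `phase`, and `causal_complex_conv`, and the random parameter ranges are all hypothetical, and the length is kept small for illustration.

```python
import numpy as np

groups, freqs, length = 4, 129, 256  # length is illustrative, not the 4096-8192 target

rng = np.random.default_rng(0)
# Hypothetical learned parameters: a positive decay rate and a phase step per mode.
decay = rng.uniform(0.01, 0.5, size=(groups, freqs))
phase = rng.uniform(0.0, np.pi, size=(groups, freqs))

t = np.arange(length)
# kernel[g, f, t] = exp((-decay + i*phase) * t): one damped complex exponential per mode.
kernel = np.exp((-decay[..., None] + 1j * phase[..., None]) * t)

def causal_complex_conv(x, kernel):
    """x: real [B, G, L]; kernel: complex [G, F, L]; returns complex [B, G, F, L]."""
    L = x.shape[-1]
    n = 2 * L  # zero-pad so circular wrap-around stays out of the first L samples
    X = np.fft.rfft(x, n=n, axis=-1)[:, :, None, :]  # [B, G, 1, L + 1]
    # A complex kernel has no Hermitian spectrum, so its real and
    # imaginary parts go through separate real FFTs.
    K_re = np.fft.rfft(kernel.real, n=n, axis=-1)    # [G, F, L + 1]
    K_im = np.fft.rfft(kernel.imag, n=n, axis=-1)
    y_re = np.fft.irfft(X * K_re, n=n, axis=-1)[..., :L]
    y_im = np.fft.irfft(X * K_im, n=n, axis=-1)[..., :L]
    return y_re + 1j * y_im

x = rng.standard_normal((2, groups, length))  # [Batch, Groups, Length]
y = causal_complex_conv(x, kernel)            # [Batch, Groups, Freqs, Length]
total_channels = groups * freqs               # 4 * 129 = 516
print(y.shape, total_channels)
```

Under this assumed parametrization, the convolution is equivalent to the diagonal recurrence `h[t] = a * h[t-1] + x[t]` with `a = exp(-decay + i*phase)`, which is where the similarity to an SSM scan comes from.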
The goal is to fuse the RFFT -> Multiply -> IRFFT sequence to avoid intermediate HBM (memory) reads and writes, similar to how Mamba's selective_scan fuses the recurrence.
I saw on your main page that you can also help with custom kernels. Is this something you might be able to develop?
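As a rough illustration of the traffic at stake, the arithmetic below estimates the size of one intermediate spectrum tensor. It assumes complex64 intermediates (8 bytes per value), one spectrum per complex channel, and an unfused implementation that materializes each spectrum in HBM. These are assumptions for the estimate, not measurements, and `spectrum_bytes` is a hypothetical helper.

```python
def spectrum_bytes(channels, seq_len, bytes_per_complex=8):
    """Bytes for one rfft spectrum tensor of a signal zero-padded to 2 * seq_len."""
    n_fft = 2 * seq_len
    bins = n_fft // 2 + 1  # rfft keeps only the non-negative frequencies
    return channels * bins * bytes_per_complex

channels = 4 * 129  # 516 complex channels, from the specs above
for seq_len in (4096, 8192):
    mib = spectrum_bytes(channels, seq_len) / 2**20
    print(f"L={seq_len}: {mib:.1f} MiB per batch element per spectrum tensor")
```

An unfused pipeline writes and re-reads tensors of this size several times per layer (after the RFFT, after the multiply, and before the IRFFT), which is the traffic a fused kernel would keep on-chip.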
Thanks,