Custom Kernel Request - Fused Causal Convolution (FFT-based) kernel #4

@Puddings22

Hello guys,

I am running a model that functions similarly to a State Space Model (SSM) or Mamba, but in the frequency domain.

My bottleneck is a Causal Spectral Decomposition step.

Currently, I implement this in PyTorch using:
Output = IFFT(FFT(Input, n=2L) * FFT(Kernel, n=2L)), truncated to the first L samples.
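For reference, here is a minimal sketch of that unfused baseline together with an O(L^2) direct causal convolution that a fused kernel could be checked against. The function names are illustrative, and the float32 cast is an assumption (half-precision support in torch.fft is limited), not necessarily what the current implementation does.

```python
import torch
import torch.nn.functional as F


def fft_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal conv y[..., t] = sum_{s<=t} k[..., s] * u[..., t - s] via zero-padded FFT.

    u: [batch, groups, L] real input; k: [groups, L] real kernel.
    """
    L = u.shape[-1]
    n = 2 * L  # zero-pad to 2L so the circular convolution equals the linear one
    u_f = torch.fft.rfft(u.float(), n=n)  # [B, G, L + 1]; computed in fp32
    k_f = torch.fft.rfft(k.float(), n=n)  # [G, L + 1], broadcasts over batch
    y = torch.fft.irfft(u_f * k_f, n=n)   # [B, G, 2L]
    return y[..., :L].to(u.dtype)         # keep the causal first L samples


def direct_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """O(L^2) reference: depthwise conv1d with a flipped kernel and left padding."""
    L, G = u.shape[-1], u.shape[1]
    weight = k.flip(-1).unsqueeze(1).to(u.dtype)  # [G, 1, L]
    return F.conv1d(F.pad(u, (L - 1, 0)), weight, groups=G)
```

Padding to 2L works because the linear convolution of two length-L sequences has length 2L - 1, so nothing wraps around into the first L outputs.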

I need a 'FlashFFTConv' kernel with these specs:

  • Input: [Batch, Groups, Length] (Real BF16)
  • Dynamics: 129 Complex Frequencies per Group (Learned Decay + Phase).
  • Sequence Length: Optimization target is 4096 to 8192.
  • Total Complex Channels: 516 (4 Groups * 129 Freqs).
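The issue does not spell out how the 129 complex frequencies become a time-domain kernel. One common reading, in the style of diagonal SSMs, is a sum of damped complex exponentials per group. The sketch below assumes that parameterization: `build_kernel`, `decay_raw`, `freq`, and `coeff` are hypothetical names, the softplus keeps the decay positive, `freq` is the per-step phase rotation, and the complex `coeff` carries amplitude and phase offset.

```python
import torch
import torch.nn.functional as F


def build_kernel(decay_raw: torch.Tensor, freq: torch.Tensor,
                 coeff: torch.Tensor, L: int) -> torch.Tensor:
    """Hypothetical kernel k[g, t] = Re(sum_n coeff[g, n] * exp((-decay[g, n] + i*freq[g, n]) * t)).

    decay_raw, freq: [G, N] real; coeff: [G, N] complex; returns [G, L] real.
    """
    t = torch.arange(L, dtype=torch.float32)  # [L] time steps
    decay = F.softplus(decay_raw)             # [G, N], > 0 so every mode decays
    # Complex exponent for every (group, mode, step): [G, N, L]
    z = torch.complex(-decay[..., None] * t, freq[..., None] * t)
    modes = torch.exp(z)
    # Weight each mode, sum over the N modes, keep the real part
    return (coeff[..., None] * modes).sum(dim=-2).real
```

With G = 4 and N = 129 this produces the [4, L] kernel that the FFT convolution consumes. Materializing the [G, N, L] intermediate is fine for a reference, though a fused kernel would presumably avoid it.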

The goal is to fuse the RFFT -> Multiply -> IRFFT sequence into a single kernel, avoiding the intermediate HBM (global memory) reads and writes between steps, similar to how Mamba's selective_scan fuses the recurrence.
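A rough way to size the potential win is to count the bytes each step moves through HBM per (batch, group) row. The sketch below rests on stated assumptions rather than measurements: the BF16 input is cast to fp32, each PyTorch op is a separate kernel that reads and writes its full tensors, the kernel spectrum is precomputed in complex64, and cuFFT's internal passes and kernel-generation cost are ignored. `hbm_bytes_per_row` is a hypothetical helper.

```python
def hbm_bytes_per_row(L: int) -> dict:
    """Back-of-envelope HBM traffic per (batch, group) row, under the assumptions above."""
    bf16, f32, c64 = 2, 4, 8
    bins = L + 1  # rfft of a length-2L signal has L + 1 bins
    unfused = (
        (L * bf16 + L * f32)              # cast input bf16 -> fp32
        + (L * f32 + bins * c64)          # rfft: read fp32 input, write spectrum
        + (2 * bins * c64 + bins * c64)   # multiply: read two spectra, write product
        + (bins * c64 + 2 * L * f32)      # irfft: read product, write 2L fp32 samples
        + (L * f32 + L * bf16)            # slice to L and cast back to bf16
    )
    # Fused: read input and kernel spectrum once, write output once
    fused = L * bf16 + bins * c64 + L * bf16
    return {"unfused": unfused, "fused": fused, "ratio": unfused / fused}
```

For L = 4096 this gives 262,184 bytes unfused versus 49,160 fused, roughly 5.3x less traffic, before accounting for the extra passes cuFFT makes internally.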

I saw on your main page that you also help with custom kernels. Is this something you might be able to develop?

Thanks,
