A ComfyUI node that patches model attention to use PyTorch Flex Attention (torch.nn.attention.flex_attention) with torch.compile. No extra dependencies — works on any GPU supported by torch.compile.
- PyTorch 2.5+ (flex_attention is built-in)
- Any CUDA GPU supported by torch.compile
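To confirm your environment meets the PyTorch 2.5+ requirement, you can probe for the flex_attention module directly. The helper name `has_flex_attention` below is hypothetical, not part of the node; this is just a minimal sketch of the check:

```python
def has_flex_attention() -> bool:
    """Return True if this PyTorch build exposes flex_attention (2.5+)."""
    try:
        # The import only succeeds on PyTorch 2.5 or newer.
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
    except ImportError:
        return False
    return True
```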
Place the Flex Attention node between your model/LoRA loader and the KSampler in your workflow.
Unlike Flash Attention 4 (which requires Blackwell GPUs), Flex Attention works on any modern NVIDIA GPU. It uses PyTorch's native torch.compile to fuse the attention kernel — no graph breaks required.
- Compiles torch.nn.attention.flex_attention via torch.compile
- First call triggers compilation (expect a one-time warmup delay)
- Subsequent calls use the compiled kernel
- Falls back to PyTorch SDPA when attention masks are present