Skip to content

This change introduces a pure JAX implementation of flash attention to Maxtext, designed as a drop-in replacement for the existing Pallas kernel. In this cl we set up the stage by integrating it with maxtext in fsdp mode. We have plans for further optimizations to close the gap with pallas using different techniques such as:#2793

Open
copybara-service[bot] wants to merge 1 commit intomainfrom
test_834764107