This change introduces a pure JAX implementation of flash attention to MaxText, designed as a drop-in replacement for the existing Pallas kernel. This CL sets the stage by integrating it with MaxText in FSDP mode. Further optimizations are planned to close the gap with Pallas, using techniques such as: #2793
Open
copybara-service[bot] wants to merge 1 commit into main
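For readers unfamiliar with the technique, the following is a minimal sketch of the idea behind a pure JAX flash attention: keys and values are processed one block at a time while a running softmax maximum and normalizer are maintained, so the full score matrix is never materialized. It is not the code from this change. The function names `blockwise_attention` and `reference_attention`, the `block_size` parameter, and the single-head, unmasked, unsharded setup are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Plain softmax attention, used only to check the blockwise version."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=128):
    """Hypothetical single-head attention computed one key/value block at a time.

    q: (q_len, head_dim), k: (kv_len, head_dim), v: (kv_len, v_dim).
    Assumes kv_len is a multiple of block_size and that no mask is applied.
    """
    kv_len, head_dim = k.shape
    assert kv_len % block_size == 0, "kv_len must be divisible by block_size"
    scale = head_dim ** -0.5
    num_blocks = kv_len // block_size
    k_blocks = k.reshape(num_blocks, block_size, head_dim)
    v_blocks = v.reshape(num_blocks, block_size, v.shape[-1])

    def step(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        scores = (q @ k_blk.T) * scale  # (q_len, block_size)
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        # Rescale the previous accumulator and normalizer to the new maximum.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(axis=-1)
        return (acc, new_max, row_sum), None

    q_len = q.shape[0]
    init = (
        jnp.zeros((q_len, v.shape[-1]), q.dtype),  # unnormalized output
        jnp.full((q_len,), -jnp.inf, q.dtype),     # running row maximum
        jnp.zeros((q_len,), q.dtype),              # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]
```

Because the sketch uses only `jax.numpy` and `jax.lax.scan`, it can be wrapped in `jax.jit` or `jax.vmap` (for example over batch and head dimensions) like any other JAX function. A production version would also need causal masking, query blocking, and a memory-efficient backward pass, none of which are shown here.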