
[Performance] Improve the flash attention performance on bottom-up optimization pipeline #2177

@chengjunlu

Description


This issue tracks the new design required for flash attention on the bottom-up optimization pipeline.

Status

Most of the optimization passes have been finished and checked into the llvm-target branch, and all the tasks in the old issue #878 have been completed. The GEMM Triton kernel written with the block pointer syntax achieves 90% of the performance of the XeTLA version. Flash attention with block pointers also shows promising performance after simple changes to the RewriteBlockPointer pass (a sketch of the block pointer syntax follows).
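For reference, a minimal sketch of a GEMM written with the block pointer syntax mentioned above; the kernel name, argument names, block sizes, and output dtype are illustrative and not taken from the tracked kernels:

```python
import triton
import triton.language as tl


@triton.jit
def gemm_block_ptr_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                          stride_am, stride_ak,
                          stride_bk, stride_bn,
                          stride_cm, stride_cn,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                          BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # A block pointer carries base/shape/strides/offsets for a whole 2D tile,
    # so the backend can lower the load to 2D block IO instead of per-element
    # gathers.
    a_bptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                               strides=(stride_am, stride_ak),
                               offsets=(pid_m * BLOCK_M, 0),
                               block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_bptr = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                               strides=(stride_bk, stride_bn),
                               offsets=(0, pid_n * BLOCK_N),
                               block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_bptr, boundary_check=(0, 1))
        b = tl.load(b_bptr, boundary_check=(0, 1))
        acc += tl.dot(a, b)
        # Advancing a block pointer only updates its offsets.
        a_bptr = tl.advance(a_bptr, (0, BLOCK_K))
        b_bptr = tl.advance(b_bptr, (BLOCK_K, 0))
    c_bptr = tl.make_block_ptr(base=c_ptr, shape=(M, N),
                               strides=(stride_cm, stride_cn),
                               offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                               block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # Assumes an fp16 output buffer for illustration.
    tl.store(c_bptr, acc.to(tl.float16), boundary_check=(0, 1))
```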

New problems

Two new problems were found while developing the bottom-up optimization pipeline:

  1. FP8 flash attention is already supported and must continue to be supported. Lowering tt.load needs a new implementation to handle FP8 for flash attention.
  2. The RewriteBlockPointer pass generates inefficient code, tracked in [Performance] Improve the code generated by the RewriteTensorPointer pass. #1766 (see the sketch after this list).
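To make the second problem concrete, a RewriteBlockPointer/RewriteTensorPointer-style rewrite replaces each block pointer load with an explicit gather that carries per-element pointer and mask tensors through the loop. The sketch below is a hand-written approximation of that rewritten form, not the actual pass output; all names are illustrative:

```python
import triton
import triton.language as tl


@triton.jit
def gemm_gather_style_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                             stride_am, stride_ak,
                             stride_bk, stride_bn,
                             stride_cm, stride_cn,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                             BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Per-element pointer tensors instead of a single 2D block descriptor.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks are recomputed and pointer tensors updated on every iteration;
        # this kind of per-lane address and mask arithmetic is the overhead
        # tracked in #1766.
        a_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < K)
        b_mask = ((k + offs_k)[:, None] < K) & (offs_n[None, :] < N)
        acc += tl.dot(tl.load(a_ptrs, mask=a_mask, other=0.0),
                      tl.load(b_ptrs, mask=b_mask, other=0.0))
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Assumes a float32 output buffer for simplicity.
    tl.store(c_ptrs, acc, mask=c_mask)
```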

Plan

To achieve both the performance and functionality goals of the bottom-up phase, we need a different implementation than originally planned.

  1. Support falling back to gather/scatter semantic memory access when lowering a tt.load operation that uses a block pointer as its memory pointer (optionally, also support falling back to Intel 1D block IO); a sketch follows this list.
  2. Remove the RewriteBlockPointer pass entirely once the memory access operations can load through a block pointer into any layout. (First step.)
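A rough sketch of what the plan means for the flash attention inner loop: the kernel keeps its block pointers end to end, and the backend decides per load whether to use 2D block IO, 1D block IO, or gather/scatter. The helper name, arguments, and the FP8 handling here are assumptions for illustration only, not the actual tutorial/benchmark kernel:

```python
import triton
import triton.language as tl


@triton.jit
def attn_inner_loop(q, k_bptr, v_bptr, acc, m_i, l_i,
                    SEQ_LEN, sm_scale,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    HEAD_DIM: tl.constexpr):
    # q: (BLOCK_M, HEAD_DIM) tile already in registers; k_bptr/v_bptr are
    # block pointers over K^T (HEAD_DIM, SEQ_LEN) and V (SEQ_LEN, HEAD_DIM).
    for _ in range(0, SEQ_LEN, BLOCK_N):
        # With the planned lowering, each of these loads maps to 2D block IO
        # when the consumer layout allows it and otherwise falls back to 1D
        # block IO or gather/scatter; no front-end rewrite of the block
        # pointer is needed. An FP8-typed V (assumed supported here) only
        # changes the element width the fallback path must handle.
        k = tl.load(k_bptr, boundary_check=(0, 1))
        v = tl.load(v_bptr, boundary_check=(0, 1))
        qk = tl.dot(q, k) * sm_scale
        # Online-softmax update (standard flash attention recurrence).
        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])
        alpha = tl.exp(m_i - m_new)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_new
        k_bptr = tl.advance(k_bptr, (0, BLOCK_N))
        v_bptr = tl.advance(v_bptr, (BLOCK_N, 0))
    return acc, m_i, l_i
```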

This design can also benefit future features such as the TMA descriptor.
