
Commit e1162ee

Enable prefetching for Hopper mixed-precision gemm [wgmma] (#6196)
@ThomasRaoux This is the PR for enabling prefetching for Hopper mixed-precision GEMM. I will explain it in detail below and would like to discuss with you how to merge this PR.

### What is this PR about

This PR automatically prefetches for Hopper mixed-precision GEMM, similar to what the current prefetch pass does for `tt.dot`. It can therefore potentially be merged into the prefetch pass [tritongpu-prefetch](https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp).

### How it does prefetching

It first detects the following pattern (just an example):

```
for k = 0 to K:
    tileA = subview[Asmem[i8x3x128x64] -> i8x128x64]
    tileA = localload(i8x128x64)
    tileA = dequantize(i8x128x64 -> bf16x128x64)
    tileB = subview(Bsmem[bf16x3x128x64] -> bf16x128x64)
    Acc += wgmma(tileA, tileB)   // tileA is in registers, tileB is in smem
    wgmma_wait(0)
endfor
```

This pattern is typically generated by the tritongpu-pipeline pass. For automatic prefetching on Hopper, this PR generates the following code instead:

```
for k = 0 to K:
    tileA = subview[Asmem[i8x3x128x64] -> i8x128x64]
    subtileA_1 = subview and localload [tileA, i8x128x16]   // first 16 columns
    subtileA_2 = subview and localload [tileA, i8x128x16]   // second 16 columns
    subtileA_3 = subview and localload [tileA, i8x128x16]   // third 16 columns
    subtileA_4 = subview and localload [tileA, i8x128x16]   // fourth 16 columns
    subtileA_1 = dequantize(subtileA_1)
    Acc += wgmma(subtileA_1, subtileB_1)   // subtileB_i: the matching subtiles of tileB, kept in smem
    subtileA_2 = dequantize(subtileA_2)
    Acc += wgmma(subtileA_2, subtileB_2)
    subtileA_3 = dequantize(subtileA_3)
    Acc += wgmma(subtileA_3, subtileB_3)
    subtileA_4 = dequantize(subtileA_4)
    Acc += wgmma(subtileA_4, subtileB_4)
    ..........
    wgmma_wait(0)   // sink wait(0) as far as possible until a dependence is met, such as a yieldOp or an if region
endfor
```

The benefits of the prefetching are:

1. The latency of localload can be hidden. For example, when doing `subtileA_1 = dequantize(subtileA_1)`, the latency of `subtileA_1 = localload[tileA]` can be hidden by the following three localloads for the other subtiles.
2. dequantize and wgmma can be interleaved to better utilize the GPU resources.
3. Sinking wait(0) as far as possible hides some of the latency of async_load.

### What is the performance benefit

We tested it on an H100 80GB HBM, using the data format int8 x bf16 -> bf16. The performance is shown in the figure below. The numbers may vary across Hopper platforms and across runs (as noise exists): 1-6% speedup on small shapes and 20-30% speedup on large shapes.

![mixe-gemm-hopper](https://github.com/user-attachments/assets/4f88d200-38ff-41aa-b1df-028b20c82a09)

Theoretically, it should also work for the case where operand B comes from a subview.

### How it can be merged

This PR can be merged into the prefetch pass: https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#L49C1-L49C19. Theoretically, we can create a template-based prefetcher and instantiate it for `tt.dot` and `ttng.warp_group_dot` separately. Once we decide how to merge it, I can create a lit test for it.
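To make the proposed merge path more concrete, here is a minimal sketch of what such a template-based prefetcher skeleton might look like, assuming the existing `Prefetcher` in Prefetch.cpp keeps roughly its initialize / emitPrologue / createNewForOp structure. All class, method, and member names below are illustrative only, not the actual implementation:

```cpp
// Illustrative sketch only (not the actual implementation): a prefetcher
// skeleton parameterized over the dot-like op it rewrites, so one template
// could be instantiated for tt.dot (Prefetch.cpp) and for
// ttng.warp_group_dot (WGMMAPrefetch.cpp).
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

template <typename DotOpT>
class PrefetcherBase {
public:
  explicit PrefetcherBase(mlir::scf::ForOp forOp) : forOp(forOp) {}
  virtual ~PrefetcherBase() = default;

  // Collect dot ops in the loop body whose prefetched operand comes from a
  // shared-memory subview + localload chain; fail if the pattern is absent.
  mlir::LogicalResult initialize() {
    forOp.getBody()->walk([&](DotOpT dot) { dots.push_back(dot); });
    return mlir::success(!dots.empty());
  }

  // Emit the subview/localload of the first sub-tile ahead of the loop.
  virtual void emitPrologue() = 0;

  // Rebuild the loop with per-sub-tile dots, rotating the prefetched
  // sub-tiles through the loop-carried values.
  virtual mlir::scf::ForOp createNewForOp() = 0;

protected:
  mlir::scf::ForOp forOp;
  llvm::SmallVector<DotOpT> dots;
};
```

The `ttng.warp_group_dot` instantiation would additionally handle the dequantize op between localload and the dot, and the sinking of `wgmma_wait(0)`, as described above.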
# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Thomas Raoux <[email protected]>
1 parent db0c34c commit e1162ee

File tree

7 files changed: +810 -1 lines changed


include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
```diff
@@ -218,6 +218,23 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
                            "mlir::arith::ArithDialect"];
 }
 
+def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
+  let summary = "prefetch for wgmma mixed precision";
+
+  let description = [{
+    This pass attempts to prefetch from shared memory for mixed-precision
+    wgmma when operand A is in the shared memory and needs to be loaded
+    to the local registers.
+  }];
+
+  let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
+                            "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+                            "mlir::scf::SCFDialect",
+                            "mlir::arith::ArithDialect"];
+}
+
+
 def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
   let summary = "accelerate matmul";
```
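For context only, a hedged sketch of how the newly defined pass could be scheduled from C++; the factory name `createTritonGPUWGMMAPrefetch()` and its namespace are assumptions based on MLIR's autogenerated pass declarations, not something shown in this diff:

```cpp
// Hypothetical usage sketch, not part of this commit. Assumes the TableGen
// definition above generates a factory createTritonGPUWGMMAPrefetch() in the
// mlir::triton::gpu namespace (the exact name depends on the generated
// pass declarations).
#include <memory>

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

namespace mlir::triton::gpu {
std::unique_ptr<mlir::Pass> createTritonGPUWGMMAPrefetch(); // assumed factory
} // namespace mlir::triton::gpu

// Schedule the pass after software pipelining, which produces the
// subview + localload + warp_group_dot pattern this pass rewrites.
static void addWGMMAPrefetch(mlir::OpPassManager &pm) {
  pm.addPass(mlir::triton::gpu::createTritonGPUWGMMAPrefetch());
}
```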

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/PipeliningUtility.cpp
   Pipeliner/Schedule.cpp
   Prefetch.cpp
+  WGMMAPrefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
   CoalesceAsyncCopy.cpp
```
