Commit 09fcc52

Allow Layouts to propagate to local_load (#5219) (#5249)

Recommit of triton-lang/triton#5219.

While working on some higher-dimensional tensor kernels, I noticed poor performance due to the fact that layouts wouldn't propagate to local loads. Since we do allow layout folding with local store and local alloc, this seems like a bit of an oversight. The change gives a 40% speed improvement on certain kernels for NVIDIA GPUs. It also removes asserts in lowering for higher-dimensional kernels; as far as I can tell, those restrictions aren't required in practice.

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- [x] I have added tests.
- [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices).

Co-authored-by: Matthew Brookhart <[email protected]>
1 parent 4107453 commit 09fcc52

2 files changed: +20 −1 lines changed
lib/Dialect/TritonGPU/Transforms/Utility.cpp (2 additions, 1 deletion)

```diff
@@ -563,7 +563,8 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
   }
   return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
              triton::MakeRangeOp, triton::SplatOp, triton::HistogramOp,
-             triton::gpu::LocalAllocOp, triton::gpu::LocalStoreOp>(op);
+             triton::gpu::LocalAllocOp, triton::gpu::LocalLoadOp,
+             triton::gpu::LocalStoreOp>(op);
 }

 scf::ForOp replaceForOpWithNewSignature(
```
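For context, `canFoldIntoConversion` reports which ops a `triton_gpu.convert_layout` can be folded into by re-creating the producer directly in the target encoding. Below is a minimal sketch of what adding `LocalLoadOp` to that list enables; the layouts, shapes, and function name are invented for illustration and are not part of this commit:

```mlir
// Hypothetical example: a local_load followed by a layout conversion.
// With LocalLoadOp accepted by canFoldIntoConversion, the pair below can
// be rewritten as a single local_load that yields #dst directly.
#src = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#dst = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func @fold_into_local_load(%mem : !triton_gpu.memdesc<128xf16, #shared, #triton_gpu.shared_memory>) -> tensor<128xf16, #dst> {
    // Before folding: load in #src, then convert to #dst.
    %a = triton_gpu.local_load %mem : !triton_gpu.memdesc<128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128xf16, #src>
    %b = triton_gpu.convert_layout %a : tensor<128xf16, #src> -> tensor<128xf16, #dst>
    // Expected after folding (sketch):
    //   %b = triton_gpu.local_load %mem : ... -> tensor<128xf16, #dst>
    tt.return %b : tensor<128xf16, #dst>
  }
}
```

The fold is presumably legal because shared memory is layout-agnostic on the register side: a `local_load` can materialize its result in whatever distributed layout the consumer wants, so the intermediate conversion never needs to exist, matching the existing treatment of `LocalAllocOp` and `LocalStoreOp`.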

test/TritonGPU/combine.mlir (18 additions, 0 deletions)

```diff
@@ -2685,3 +2685,21 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   tt.return
 }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: lift_convert_to_local_load
+  // CHECK-NOT: convert_layout
+  // CHECK: tt.return
+  tt.func public @lift_convert_to_local_load(%arg0 : !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> {
+    %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked>
+    %2 = tt.trans %1 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1>
+    %3 = triton_gpu.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2>
+    tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2>
+  }
+}
```
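Note how the test exercises the new fold: the `convert_layout` sits behind a `tt.trans`, so eliminating it requires propagating the target layout backward through the transpose and into the `local_load`. `CHECK-NOT: convert_layout` then verifies that no conversion survives the pipeline, which only holds once `LocalLoadOp` is in the `canFoldIntoConversion` list above.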
