Commit b3e233e

[GLUON] Improve layout check message (#8456)
The following script crashes the compiler with an ambiguous error message:

```python
import torch
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia import blackwell
from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma


@gluon.jit
def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
    # smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout)
    smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], mbarrier.MBarrierLayout())
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)

    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8)
    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
    mbarrier.wait(bar, 0)
    mbarrier.invalidate(bar)

    tma.async_copy_shared_to_global(input_desc, [0, 0], smem)
    tma.store_wait(0)


def test_async_tma():
    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float32)
    XBLOCK = 128
    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK], shared_layout)

    h = async_tma_kernel[(1, )](input_desc, XBLOCK, shared_layout, num_warps=1)
    print(f"input_desc: {input_desc}")


test_async_tma()
```

The reason is that `smem` misuses Gluon's layout: it is allocated with an `MBarrierLayout` instead of the NVMMA shared layout that TMA expects (a corrected allocation is sketched after this description). This PR adds a check at the Gluon language level to emit a better crash message. Adding the check in `async_tma_copy_global_to_local`'s verifier, however, would crash Triton's lowering path (`triton-nvidia-tma-lowering`), so I think the check should only be added for Gluon.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `error message is improved`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
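For reference, a minimal sketch of the corrected kernel, assuming the same Gluon API as the reproducer above. The only change is that `smem` uses the `NVMMASharedLayout` passed in as `smem_layout` (the commented-out line in the reproducer); the `_fixed` suffix is purely illustrative.

```python
# Sketch only: same imports and API as the reproducer above.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma


@gluon.jit
def async_tma_kernel_fixed(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
    # Allocate the TMA destination with the NVMMA shared layout, not MBarrierLayout.
    smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout)
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)

    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8)
    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
    mbarrier.wait(bar, 0)
    mbarrier.invalidate(bar)

    tma.async_copy_shared_to_global(input_desc, [0, 0], smem)
    tma.store_wait(0)
```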
Parent commit: 814b862

File tree: 6 files changed (+45, -19 lines)

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -208,6 +208,8 @@ LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
     return emitOpError("TMA copies must have between 1 and 5 coordinates");
   if (!getResult().getType().getMutableMemory())
     return emitOpError("Cannot store into immutable memory");
+  if (!isa<NVMMASharedEncodingAttr>(getResult().getType().getEncoding()))
+    return emitOpError("TMA result must have NVMMA shared layout");
   return success();
 }
 
```
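With this verifier clause, the layout misuse in the reproducer should surface as the explicit "TMA result must have NVMMA shared layout" diagnostic instead of an ambiguous crash deeper in the compiler. Below is a hypothetical smoke check, assuming a CUDA device and the Gluon API from the reproducer; the exact exception type Triton raises is not pinned down here, so it catches broadly, and the helper name is illustrative only.

```python
# Hypothetical harness: launch the buggy kernel from the reproducer and check
# that the failure now mentions the NVMMA shared layout requirement.
import torch
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


def expect_layout_error(buggy_kernel):
    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float32)
    XBLOCK = 128
    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK], shared_layout)
    try:
        buggy_kernel[(1,)](input_desc, XBLOCK, shared_layout, num_warps=1)
    except Exception as exc:  # exact exception class depends on Triton's error reporting
        assert "NVMMA shared layout" in str(exc)
    else:
        raise AssertionError("expected the layout check to reject this kernel")


# Usage (assumption): expect_layout_error(async_tma_kernel) with the reproducer's kernel.
```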

test/Hopper/WarpSpecialization/ws_code_partition.mlir

Lines changed: 2 additions & 2 deletions

```diff
@@ -264,7 +264,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // -----
 
-// CHECK-DAG: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+// CHECK-DAG: #[[$SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = true, elementBitWidth = 32, CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 // CHECK-DAG: #[[$SHARED1:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
 // CHECK-LABEL: @_fbgemm_grouped_gemm_fp8_rowwise_ws
 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
@@ -275,7 +275,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
-#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = true, elementBitWidth = 32, CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
```

test/TritonGPU/pipeline-lower-loop.mlir

Lines changed: 1 addition & 1 deletion

```diff
@@ -1384,7 +1384,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #linear2 = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
-#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1, 1], CTAOrder = [4, 3, 2, 1, 0]}>
 #shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8, fp4Padded = true}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, colStride = 1>
```

test/TritonNvidiaGPU/invalid.mlir

Lines changed: 18 additions & 0 deletions

```diff
@@ -89,3 +89,21 @@ tt.func @wgmma(%a: tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kW
   tt.return
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<tensor<1x256x32xf32, #shared>>) -> tensor<256x32xf32, #blocked> {
+  %true = arith.constant true
+  %c32_i32 = arith.constant 32 : i32
+  %0 = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #shared1, #smem, mutable>
+  %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
+  // expected-error @below {{TMA result must have NVMMA shared layout}}
+  ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc<tensor<1x256x32xf32, #shared>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<256x32xf32, #shared1, #smem, mutable>
+}
+}
```

test/TritonNvidiaGPU/membar.mlir

Lines changed: 9 additions & 6 deletions

```diff
@@ -82,25 +82,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // -----
 
 
-#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-tt.func public @tma_load(%arg0: !tt.tensordesc<tensor<128x64xf16, #shared0>>, %arg1: i32) -> tensor<128x64xf16, #blocked0> {
+tt.func public @tma_load(%arg0: !tt.tensordesc<tensor<128x64xf16, #shared>>, %arg1: i32) -> tensor<128x64xf16, #blocked0> {
   // CHECK-LABEL: tma_load
   // CHECK: local_dealloc
   // CHECK-NEXT: local_alloc
   // CHECK-NEXT: local_alloc
-  // CHECK-NEXT: ttg.local_barrier
+  // CHECK-NEXT: local_barrier
   // CHECK-NEXT: init_barrier
   %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0>
-  %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable>
-  ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable>
-  %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<tensor<128x64xf16, #shared0>> -> tensor<128x64xf16, #blocked0>
+  %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared1, #smem, mutable>
+  ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared1, #smem, mutable>
+  %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked0>
   tt.return %l : tensor<128x64xf16, #blocked0>
 }
 }
 
+
 // -----
 
 #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
```

test/TritonNvidiaGPU/tma_lowering.mlir

Lines changed: 13 additions & 10 deletions

```diff
@@ -90,15 +90,15 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<tensor<1x128xbf16, #nvmma_128>>, %arg
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0]}>
-// CHECK: #[[$SHARED:.+]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+// CHECK: #[[$NVMMA:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABLE: @rank_reducing_load
-tt.func public @rank_reducing_load(%arg0: !tt.tensordesc<tensor<1x256x32xf32, #shared>>) -> tensor<256x32xf32, #blocked> {
+tt.func public @rank_reducing_load(%arg0: !tt.tensordesc<tensor<1x256x32xf32, #nvmma_128>>) -> tensor<256x32xf32, #blocked> {
   %c32_i32 = arith.constant 32 : i32
-  // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #[[$SHARED]], #smem, mutable>
+  // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #[[$NVMMA]], #smem, mutable>
   // CHECK: tng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}] %[[A]],
-  %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc<tensor<1x256x32xf32, #shared>> -> tensor<256x32xf32, #blocked>
+  %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc<tensor<1x256x32xf32, #nvmma_128>> -> tensor<256x32xf32, #blocked>
   tt.return %l : tensor<256x32xf32, #blocked>
 }
 }
@@ -107,16 +107,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
 #smem = #ttg.shared_memory
+// CHECK: #[[$NVMMA:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @tma_load_alloc_user
-tt.func public @tma_load_alloc_user(%arg0: !tt.tensordesc<tensor<64x64xf32, #shared>>, %arg1: i32) -> (tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>) {
-  %0 = tt.descriptor_load %arg0[%arg1, %arg1, %arg1] : !tt.tensordesc<tensor<64x64xf32, #shared>> -> tensor<64x64xf32, #blocked>
-  // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<64x64xf32
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} %[[A]],
+tt.func public @tma_load_alloc_user(%arg0: !tt.tensordesc<tensor<64x64xf32, #nvmma_128>>, %arg1: i32) -> (tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>) {
+  %0 = tt.descriptor_load %arg0[%arg1, %arg1, %arg1] : !tt.tensordesc<tensor<64x64xf32, #nvmma_128>> -> tensor<64x64xf32, #blocked>
+  // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<64x64xf32, #[[$NVMMA]], #smem, mutable>
+  // CHECK: tng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}] %[[A]],
   %1 = ttg.local_alloc %0 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable>
   // CHECK: %[[L:.+]] = ttg.local_load %[[A]] :
-  // CHECK: tt.return %[[L]], %[[A]] :
+  // CHECK: %[[S:.+]] = ttg.local_alloc %[[L]] :
+  // CHECK: tt.return %[[L]], %[[S]] :
   tt.return %0, %1 : tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>
 }
 }
```
