
Commit f33a0a4

[mlir][nvgpu] Improve tensormap.descriptor Type Verifier (#77904)
This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor carries the information needed for TMA (Tensor Memory Accelerator) transfers, and the compile-time check now enforces the hardware restrictions on it, such as the last memory dimension of a 2-D or higher descriptor spanning exactly 128 bytes. Catching these violations at compile time prevents crashes at runtime. See the CUDA Driver API documentation for more details: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
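As a rough illustration of what the new checks enforce (a sketch, not taken from the patch itself, reusing the descriptor syntax that appears in the tests below): for a 2-D or higher descriptor the innermost dimension must span exactly 128 bytes and every dimension must lie in [1, 256], and the verifier fires when the descriptor is consumed by `nvgpu.tma.create.descriptor` or `nvgpu.tma.async.load`.

// Accepted: innermost dimension is 64 x f16 = 64 x 2 bytes = 128 bytes, and every dimension is within [1, 256].
!okDesc = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>

// Rejected once used by a TMA op: innermost dimension is 128 x 2 bytes = 256 bytes, so the verifier now emits the "last dimension of 128 bytes" diagnostic instead of letting the CUDA driver fail at runtime.
!badDesc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>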
1 parent 25ab2fc commit f33a0a4

5 files changed: +78, -36 lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 10 additions & 2 deletions
@@ -25,8 +25,16 @@ constexpr int kWarpSize = 32;
 
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
-/// Maximum tensor dimension that TMA supports
-constexpr int kMaxTMATensorDimension = 5;
+
+/// Maximum TMA tile dimension (tensorRank) must be non-zero and less than or
+/// equal to the maximum supported dimensionality of 5.
+constexpr unsigned kMaxTMATensorDimension = 5;
+/// Maximum TMA tile size (boxDim), which specifies number of elements
+/// to be traversed along each of the kMaxTMATensorDimension (tensorRank)
+/// dimensions, must be non-zero and less than or equal to 256.
+constexpr unsigned kMaxTMADimension = 256;
+/// Last dimension of 2D+ TMA must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 17 additions & 0 deletions
@@ -355,6 +355,23 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
   if (!descMemref.hasStaticShape())
     return op->emitError() << "the tensor map descriptor must be static shaped";
 
+  for (auto dim : descMemref.getShape()) {
+    if (dim <= 0 || dim > kMaxTMADimension) {
+      return op->emitError() << "the tensor map descriptor must have "
+                                "dimensions between 1 and "
+                             << kMaxTMADimension << " but it is " << dim;
+    }
+  }
+  if (descMemref.getRank() > 1) {
+    unsigned lastDimensionByte =
+        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+    if (lastDimensionByte != kMaxTMALastdimByte)
+      return op->emitError() << "the tensormap descriptor must have last "
+                                "dimension of "
+                             << kMaxTMALastdimByte << " bytes but it is "
+                             << lastDimensionByte << " bytes";
+  }
+
   // No verification if memref type is not provided
   if (!memrefType.has_value())
     return std::nullopt;
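For reference, a minimal sketch (hypothetical, not part of this commit's test updates) of a descriptor that would trip the other new check above, the per-dimension bound of kMaxTMADimension = 256, when used by a TMA op:

// Hypothetical example: a 512-element dimension exceeds the [1, 256] boxDim bound, so the new "dimensions between 1 and 256" diagnostic should be reported by nvgpu.tma.create.descriptor or nvgpu.tma.async.load; the 64 x f16 innermost dimension (128 bytes) itself is fine.
!tooWideDesc = !nvgpu.tensormap.descriptor<tensor = memref<512x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>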

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 7 additions & 10 deletions
@@ -799,10 +799,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
 }
 
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
 
 module @mymodule {
   // Dynamic Shared memory
@@ -811,17 +808,17 @@ module @mymodule {
   func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
     %c0 = arith.constant 0 : index
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
-    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128], strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
-    %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
-    %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64], strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+    %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1] : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
-    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
     // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
     // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
-    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     return
   }
 }

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 20 additions & 0 deletions
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memr
   nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+  %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+  return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+  return
+}

mlir/test/Dialect/NVGPU/tmaload-transform.mlir

Lines changed: 24 additions & 24 deletions
@@ -3,35 +3,35 @@
 // RUN: -test-transform-dialect-erase-schedule \
 // RUN: | FileCheck %s
 
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
 // CHECK-LABEL: func.func @main()
 func.func @main() {
   %c1 = arith.constant 1 : index
   %c128 = arith.constant 128 : index
 
   %0 = gpu.wait async
-  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
-  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
 
-  // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   // CHECK: %[[c64:.*]] = arith.constant 64 : index
-  // CHECK: %[[c8:.*]] = arith.constant 8 : index
-  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
-  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-  // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  // CHECK: %[[c32:.*]] = arith.constant 32 : index
+  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
+  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   // CHECK: %[[c8_2:.*]] = arith.constant 8 : index
-  // CHECK: %[[c128_2:.*]] = arith.constant 128 : index
-  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
-  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK: %[[c32_2:.*]] = arith.constant 32 : index
+  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
+  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
-    // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
-    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
     // CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
    // CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,18 +45,18 @@ func.func @main() {
     //
     // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
-    // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME: : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME: -> memref<64x32xf32, #gpu.address_space<workgroup>>
     //
     // CHECK: %[[c0_8:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
-    // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME: : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME: -> memref<8x32xf32, #gpu.address_space<workgroup>>
     //
-    // CHECK: %[[c6144:.*]] = arith.constant 6144 : index
-    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+    // CHECK: %[[c9216:.*]] = arith.constant 9216 : index
+    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c9216]] : <memorySpace = #gpu.address_space<workgroup>
     // CHECK: } else {
     // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
     // CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
     /// Both copies are matched and end up in the same async group.
-    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
-    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
 
     gpu.terminator
   }
