Commit e3467d8

[mlir][nvgpu] Fix tma descriptor check (#152160)
The TMA descriptor check does not appear to be correct: it requires the last dimension of the memref to be 128 bytes, but the last dimension only needs to span the same number of bytes as the swizzle width.
1 parent f45f4ae commit e3467d8
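
For illustration only, a minimal standalone sketch of the corrected rule (not the dialect code itself; the Swizzle enum and lastDimMatches helper are made-up names, and the element widths and last-dimension sizes are taken from the updated tests below):

// Standalone sketch: mirrors the getSwizzleBytes helper introduced in
// NVGPUDialect.cpp below and the arithmetic used by the verifier.
#include <cstdio>

enum class Swizzle { B32 = 32, B64 = 64, B128 = 128 };

unsigned swizzleBytes(Swizzle s) { return static_cast<unsigned>(s); }

// A memref passes the check when its last dimension spans exactly the
// swizzle width in bytes.
bool lastDimMatches(unsigned elementBits, unsigned lastDimSize, Swizzle s) {
  return elementBits * lastDimSize / 8 == swizzleBytes(s);
}

int main() {
  // Shapes from the updated tests, f32 elements (32 bits):
  std::printf("%d\n", lastDimMatches(32, 8, Swizzle::B32));    // 8  * 4 B = 32 B  -> 1
  std::printf("%d\n", lastDimMatches(32, 16, Swizzle::B64));   // 16 * 4 B = 64 B  -> 1
  std::printf("%d\n", lastDimMatches(32, 32, Swizzle::B128));  // 32 * 4 B = 128 B -> 1
  // The invalid.mlir case: f16 (16 bits) with last dimension 128 -> 256 B, rejected.
  std::printf("%d\n", lastDimMatches(16, 128, Swizzle::B128)); // -> 0
  return 0;
}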

3 files changed: +47 / -33 lines

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 16 additions & 2 deletions
@@ -345,6 +345,19 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//

+unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
+  switch (kind) {
+  case TensorMapSwizzleKind::SWIZZLE_32B:
+    return 32;
+  case TensorMapSwizzleKind::SWIZZLE_64B:
+    return 64;
+  case TensorMapSwizzleKind::SWIZZLE_128B:
+    return 128;
+  default:
+    return 0;
+  }
+}
+
 std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
     Operation *op, nvgpu::TensorMapDescriptorType descType,
     std::optional<MemRefType> memrefType = std::nullopt) {
@@ -373,10 +386,11 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
       descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
     unsigned lastDimensionByte =
         descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
-    if (lastDimensionByte != kMaxTMALastdimByte)
+    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
+    if (lastDimensionByte != expectByte)
       return op->emitError() << "the tensormap descriptor must have last "
                                 "dimension of "
-                             << kMaxTMALastdimByte << " bytes but it is "
+                             << expectByte << " bytes but it is "
                              << lastDimensionByte << " bytes";
   }

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 21 additions & 21 deletions
@@ -664,15 +664,15 @@ func.func @mbarrier_txcount_pred() {

 // CHECK-LABEL: func @async_tma_load
 !tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = none>
-!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
-!tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
+!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x8xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x16xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
 !tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
 !tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
 func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
                           %buffer1d: memref<128xf32,3>,
-                          %buffer2d: memref<32x32xf32,3>,
-                          %buffer3d: memref<2x32x32xf32,3>,
+                          %buffer2d: memref<32x8xf32,3>,
+                          %buffer3d: memref<2x32x16xf32,3>,
                           %buffer4d: memref<2x2x32x32xf32,3>,
                           %buffer5d: memref<2x2x2x32x32xf32,3>,
                           %mbarrier: !mbarrier) {
@@ -682,9 +682,9 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
   nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d : !tensorMap1d, !mbarrier -> memref<128xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}]
-  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d : !tensorMap2d, !mbarrier -> memref<32x8xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}]
-  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d : !tensorMap3d, !mbarrier -> memref<2x32x16xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
@@ -708,8 +708,8 @@ func.func @async_tma_load_gpu_address_space(%tensorMap1d: !tensorMap1dgpuspace,
 // CHECK-LABEL: func @async_tma_load_pred
 func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
                                %buffer1d: memref<128xf32,3>,
-                               %buffer2d: memref<32x32xf32,3>,
-                               %buffer3d: memref<2x32x32xf32,3>,
+                               %buffer2d: memref<32x8xf32,3>,
+                               %buffer3d: memref<2x32x16xf32,3>,
                                %buffer4d: memref<2x2x32x32xf32,3>,
                                %buffer5d: memref<2x2x2x32x32xf32,3>,
                                %mbarrier: !mbarrier,
@@ -720,9 +720,9 @@ func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensor
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d, predicate = %p : !tensorMap1d, !mbarrier -> memref<128xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}] predicate = %{{.*}}
-  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x8xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
-  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x16xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
   nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d, predicate = %p : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] predicate = %{{.*}}
@@ -734,7 +734,7 @@ func.func @async_tma_load_multicast(
     %tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d,
     %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d,
     %tensorMap5d: !tensorMap5d, %buffer1d: memref<128xf32,3>,
-    %buffer2d: memref<32x32xf32,3>, %buffer3d: memref<2x32x32xf32,3>,
+    %buffer2d: memref<32x8xf32,3>, %buffer3d: memref<2x32x16xf32,3>,
     %buffer4d: memref<2x2x32x32xf32,3>, %buffer5d: memref<2x2x2x32x32xf32,3>,
     %mbarrier: !mbarrier,
     %multicastMask: i16) {
@@ -744,9 +744,9 @@ func.func @async_tma_load_multicast(
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
   nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d multicast_mask = %multicastMask : !tensorMap1d, !mbarrier -> memref<128xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}]
-  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d multicast_mask = %multicastMask : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d multicast_mask = %multicastMask : !tensorMap2d, !mbarrier -> memref<32x8xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}]
-  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d multicast_mask = %multicastMask : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d multicast_mask = %multicastMask : !tensorMap3d, !mbarrier -> memref<2x32x16xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d multicast_mask = %multicastMask : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
   // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
@@ -756,8 +756,8 @@ func.func @async_tma_load_multicast(

 func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
                            %buffer1d: memref<128xf32,3>,
-                           %buffer2d: memref<32x32xf32,3>,
-                           %buffer3d: memref<2x32x32xf32,3>,
+                           %buffer2d: memref<32x8xf32,3>,
+                           %buffer3d: memref<2x32x16xf32,3>,
                            %buffer4d: memref<2x2x32x32xf32,3>,
                            %buffer5d: memref<2x2x2x32x32xf32,3>) {
   %c0 = arith.constant 0 : index
@@ -766,9 +766,9 @@ func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
   nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
-  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x8xf32,3> -> !tensorMap2d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
-  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x16xf32,3> -> !tensorMap3d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
   nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
@@ -779,8 +779,8 @@ func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2

 func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
                                      %buffer1d: memref<128xf32,3>,
-                                     %buffer2d: memref<32x32xf32,3>,
-                                     %buffer3d: memref<2x32x32xf32,3>,
+                                     %buffer2d: memref<32x8xf32,3>,
+                                     %buffer3d: memref<2x32x16xf32,3>,
                                      %buffer4d: memref<2x2x32x32xf32,3>,
                                      %buffer5d: memref<2x2x2x32x32xf32,3>,
                                      %p: i1) {
@@ -790,9 +790,9 @@ func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
   nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
-  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x8xf32,3> -> !tensorMap2d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
-  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x16xf32,3> -> !tensorMap3d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
   nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
   // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 10 additions & 10 deletions
@@ -276,14 +276,14 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc: !tR

 // -----

-!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
-func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x8xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
   %c0 = arith.constant 0 : index
   // Pass fine
-  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x8xf32,3>
   // expected-error @+1 {{Maximum 5 coordinates are supported.}}
-  nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x8xf32,3>
   return
 }
 // -----
@@ -298,17 +298,17 @@ func.func @tma_load_2(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memr
 }
 // -----

-!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
-func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x8xf32>, %mbarrier: !mbarrier) {
   %c0 = arith.constant 0 : index
   // expected-error @+1 {{the destination memref has incorrect address space, it must be shared memory address space}}
-  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x32xf32>
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x8xf32>
   return
 }
 // -----

-!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
 func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
   %c0 = arith.constant 0 : index
@@ -319,7 +319,7 @@ func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memr

 // -----

-!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_128b, l2promo = none, oob = zero, interleave = none>
 func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
   // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
   %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
@@ -328,7 +328,7 @@ func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index,
 // -----


-!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_128b, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
 func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
   %c0 = arith.constant 0 : index
