Skip to content

Commit 6d08a39

Browse files
[mlir][nvgpu] Add tma last dim bytes check (#153451)
Add a check that the number of bytes in the last dimension of a TMA tensor map is a multiple of 16.
1 parent 87de48d commit 6d08a39

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
422422
<< descMemref << " != " << dstMemref;
423423
}
424424

425+
int lastDimBytes =
426+
descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
427+
if (lastDimBytes % 16 != 0) {
428+
return op->emitError() << "the bytes in the last dimension of the tensor "
429+
"map must be a multiple of 16";
430+
}
425431
return std::nullopt;
426432
}
427433

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,14 @@ func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %ar
378378
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
379379
return %d : vector<2x2xf16>
380380
}
381+
382+
// -----
383+
384+
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xi8,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
385+
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
386+
// Negative test: the descriptor's innermost dimension is 8 x i8 = 8 bytes, which is not a multiple of 16, so the TMA load below must be rejected.
func.func @tma_last_dim_bytes(%desc: !desc, %buffer: memref<32x8xi8,3>, %mbarrier: !mbarrier) {
387+
%c0 = arith.constant 0 : index
388+
// expected-error @+1 {{the bytes in the last dimension of the tensor map must be a multiple of 16}}
389+
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x8xi8,3>
390+
return
391+
}

0 commit comments

Comments
 (0)