diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h index 61a57fb60bda4..ccee0de65caa5 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -46,8 +46,8 @@ constexpr unsigned kMaxTMATensorDimension = 5; /// to be traversed along each of the kMaxTMATensorDimension (tensorRank) /// dimensions, must be non-zero and less than or equal to 256. constexpr unsigned kMaxTMADimension = 256; -/// Last dimension of 2D+ TMA must be 128 bytes -constexpr unsigned kMaxTMALastdimByte = 128; +/// The bytes in the last dimension of the tensor map must be a multiple of 16. +constexpr unsigned kTMALastdimByte = 16; #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 8474244c7d7c0..697cb35a59a28 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -424,7 +424,7 @@ std::optional verifyTmaDescriptorWithMemref( int lastDimBytes = descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8; - if (lastDimBytes % 16 != 0) { + if (lastDimBytes % kTMALastdimByte != 0) { return op->emitError() << "the bytes in the last dimension of the tensor " "map must be a multiple of 16"; }