diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7cb4b5c346ad9..6b462de144d1f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
   }];
 }
 
+def NVVM_CpAsyncBulkTensorPrefetchOp :
+  NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    LLVM_AnyPointer:$tmaDescriptor,
+    Variadic<I32>:$coordinates,
+    Variadic<I16>:$im2colOffsets,
+    Optional<I64>:$l2CacheHint);
+
+  let description = [{
+    Initiates an asynchronous prefetch operation on the tensor data from global
+    memory to L2 cache.
+
+    The Op has two modes:
+    1) Tiled Mode: This is the default mode. The source multi-dimensional tensor
+       layout is preserved at the destination.
+
+    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+       The elements in the Bounding Box of the source tensor are rearranged into
+       columns at the destination. In this mode, the tensor has to be at least
+       3-dimensional.
+
+    The `l2CacheHint` operand is optional, and it is used to specify the cache
+    eviction policy that may be used during the memory access.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
+  }];
+
+  let assemblyFormat = [{
+    $tmaDescriptor `,`
+    `box` `[`$coordinates `]`
+    (`im2col` `[` $im2colOffsets^ `]` )?
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict `:` type($tmaDescriptor)
+  }];
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+  }];
+
+  let hasVerifier = 1;
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // tmaDesc, tensorDims, im2colOffsets,
+    // cache_hint (if applicable) and flag (boolean)
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($tmaDescriptor);
+
+    for (auto v : op.getCoordinates())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    for (auto v : op.getIm2colOffsets())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
+        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
+    createIntrinsicCall(builder, intId, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5ab64ea1b2097..d28194d5c0029 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,22 +75,32 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  if (getCoordinates().empty() || getCoordinates().size() > 5)
-    return emitError("expects coordinates between 1 to 5 dimension");
-
-  // Check for im2col mode
-  if (!getIm2colOffsets().empty()) {
-    if (getCoordinates().size() < 3)
+// This verifier is shared across:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
+static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+                                                     size_t numIm2ColOffsets,
+                                                     Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+  if (numIm2ColOffsets) {
+    if (tensorDims < 3)
       return emitError(
+          loc,
           "to use im2col mode, the tensor has to be at least 3-dimensional");
-    if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
+    if (tensorDims != (numIm2ColOffsets + 2))
       return emitError(
-          "im2col offsets must be 2 less than number of coordinates");
+          loc, "im2col offsets must be 2 less than number of coordinates");
   }
   return success();
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+                                         getIm2colOffsets().size(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   if (getCoordinates().size() > 5)
     return emitError("Maximum 5 coordinates and dimension is supported.");
@@ -108,6 +118,11 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+                                         getIm2colOffsets().size(), getLoc());
+}
+
 // Given the element type of an operand and whether or not it is an accumulator,
 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
 // operand's element type.
@@ -1055,6 +1070,30 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
+                                                                bool isIm2Col) {
+  switch (tensorDims) {
+  case 1:
+    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
+  case 2:
+    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
+  case 3:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
+  case 4:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
+  case 5:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
+  default:
+    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0e563808da970..58282adf4dda8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
   // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
   nvvm.fence.proxy.release #nvvm.mem_scope from_proxy=#nvvm.proxy_kind to_proxy=#nvvm.proxy_kind
   llvm.return
-}
\ No newline at end of file
+}
+
+// -----
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 75ce958b43fd3..e5ea03ff7e001 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
   nvvm.breakpoint
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @tma_prefetch_1d
+llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_2d
+llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_3d
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_4d
+llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_5d
+llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
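
For reference, a minimal usage sketch of the new op (not part of the patch; %tma, %d0-%d2, %off0 and %ch are placeholder SSA values) combining the tiled and im2col forms exercised by the tests above, with comments noting the intrinsic each form selects via getIntrinsicID():

llvm.func @tma_prefetch_sketch(%tma : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
  // Tiled mode, no cache hint: lowers to
  // @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(..., i64 undef, i1 false).
  nvvm.cp.async.bulk.tensor.prefetch %tma, box[%d0, %d1, %d2] : !llvm.ptr

  // Im2col mode with an L2 cache hint: a 3-d box takes exactly one im2col offset
  // (tensorDims == numIm2ColOffsets + 2) and lowers to
  // @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(..., i64 %ch, i1 true).
  nvvm.cp.async.bulk.tensor.prefetch %tma, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}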