diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index f1eae15d6bf18..c7b0724349940 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -49,6 +49,11 @@ enum NVVMMemorySpace {
   kSharedClusterMemorySpace = 7,
 };
 
+/// A pair type of LLVM's Intrinsic ID and args (which are llvm values).
+/// This type is returned by the getIntrinsicIDAndArgs() methods.
+using IDArgPair =
+    std::pair<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>>;
+
 /// Return the element type and number of elements associated with a wmma matrix
 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
 /// WMMA_REGS structure.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a8e7dcb54ac20..a3ea5c3201ba1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2599,15 +2599,37 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
 }
 
 def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
-  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
+  NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
   let summary = "Async bulk copy from Shared CTA memory to Global memory";
   let description = [{
     Initiates an asynchronous copy operation from Shared CTA memory to
-    global memory.
+    global memory. The 32-bit operand `size` specifies the amount of
+    memory to be copied, in bytes; it must be a multiple of 16. The
+    `l2CacheHint` operand is optional, and is used to specify the cache
+    eviction policy that may be used during the memory access. The
+    `byteMask` operand is also optional. The i-th bit in the 16-bit wide
+    `byteMask` specifies whether the i-th byte of each 16-byte wide
+    chunk of source data is copied to the destination. If the bit is
+    set, the byte is copied.
+
+    Example:
+    ```mlir
+    nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size
+      : !llvm.ptr<1>, !llvm.ptr<3>
+
+    // with l2_cache_hint
+    nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch
+      : !llvm.ptr<1>, !llvm.ptr<3>
+
+    // with byte_mask
+    nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
+      : !llvm.ptr<1>, !llvm.ptr<3>
+
+    // with both l2_cache_hint and byte_mask
+    nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask
+      : !llvm.ptr<1>, !llvm.ptr<3>
+    ```
 
-    The `l2CacheHint` operand is optional, and it is used to specify cache
-    eviction policy that may be used during the memory access.
-
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
   }];
@@ -2615,35 +2637,25 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
     LLVM_PointerGlobal:$dstMem,
     LLVM_PointerShared:$srcMem,
     I32:$size,
-    Optional<I64>:$l2CacheHint);
+    Optional<I64>:$l2CacheHint,
+    Optional<I16>:$byteMask);
 
   let assemblyFormat = [{
     $dstMem `,` $srcMem `,` $size
    (`l2_cache_hint` `=` $l2CacheHint^ )?
-    attr-dict `:` type($dstMem) `,` type($srcMem)
+    (`byte_mask` `=` $byteMask^ )?
+    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];
 
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // dst, src, size, cache_hint,
-    // Flag for cache_hint
-    //
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($dstMem);
-    translatedOperands.push_back($srcMem);
-    translatedOperands.push_back($size);
-
-    // Cachehint, if available
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-
-    // Flag argument for cachehint
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    createIntrinsicCall(builder,
-        llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
+    auto [id, args] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..eb19762e1a2fa 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1253,6 +1253,34 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
   return id;
 }
 
+mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+  llvm::Intrinsic::ID id =
+      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
+
+  // Fill the intrinsic args: dst, src, size.
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getSize()));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  // Choose the bytemask variant of the intrinsic when a mask is given.
+  if (mlir::Value byteMask = thisOp.getByteMask()) {
+    args.push_back(mt.lookupValue(byteMask));
+    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
+  }
+
+  return {id, std::move(args)};
+}
+
 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                 bool isIm2Col) {
   switch (tensorDims) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 39b703d9a9677..0daf24536a672 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
 // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
 llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
   nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
 
   nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
   llvm.return
 }
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
+llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+  llvm.return
+}
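
For intuition about the `byte_mask` semantics documented in the op description above, here is a small host-side C++ sketch. It is illustrative only and not part of the patch; the helper name `bulkCopyWithByteMask` is made up. It emulates the masking rule the description states: bit i of the 16-bit mask decides whether byte i of each 16-byte chunk of source data reaches the destination.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical reference emulation of the byte_mask copy semantics:
// for every 16-byte chunk, byte i is copied iff bit i of the mask is set.
static void bulkCopyWithByteMask(uint8_t *dst, const uint8_t *src,
                                 uint32_t size /* multiple of 16 */,
                                 uint16_t byteMask) {
  for (uint32_t chunk = 0; chunk < size; chunk += 16)
    for (uint32_t i = 0; i < 16; ++i)
      if (byteMask & (1u << i))
        dst[chunk + i] = src[chunk + i];
}

int main() {
  uint8_t src[16], dst[16] = {0};
  for (int i = 0; i < 16; ++i)
    src[i] = static_cast<uint8_t>(i + 1);
  // 0x00FF: only the lower 8 bytes of each 16-byte chunk are copied.
  bulkCopyWithByteMask(dst, src, sizeof(src), 0x00FF);
  for (unsigned v : dst)
    printf("%u ", v);
  printf("\n"); // prints: 1 2 3 4 5 6 7 8 0 0 0 0 0 0 0 0
  return 0;
}
```

A mask of 0xFFFF degenerates to the unmasked bulk copy, which matches the lowering above: the plain `nvvm_cp_async_bulk_shared_cta_to_global` intrinsic is used when no `byteMask` operand is present, and the `_bytemask` variant only when one is given.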