diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index fc38a3fb2d387..6137bb087c576 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -36,12 +36,16 @@ constexpr int kSharedMemoryAlignmentBit = 128; /// NVVM memory space identifiers. enum NVVMMemorySpace { + /// Generic memory space identifier. + kGenericMemorySpace = 0, /// Global memory space identifier. kGlobalMemorySpace = 1, /// Shared memory space identifier. kSharedMemorySpace = 3, /// Constant memory space identifier. kConstantMemorySpace = 4, + /// Local memory space identifier. + kLocalMemorySpace = 5, /// Tensor memory space identifier. /// Tensor memory is available only in arch-accelerated /// variants from sm100 onwards. diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 596a584d485ed..026c1fae0eb89 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td" def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>; def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; +def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>; def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>; def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>; @@ -118,6 +119,25 @@ class NVVM_Attr traits = []> let mnemonic = attrMnemonic; } +// Cache Eviction Priority enum definitions +def EvictNormal : I32EnumCase<"EvictNormal", 0, "evict_normal">; +def EvictFirst : I32EnumCase<"EvictFirst", 1, "evict_first">; +def EvictLast : I32EnumCase<"EvictLast", 2, "evict_last">; +def EvictUnchanged : I32EnumCase<"EvictUnchanged", 3, "evict_unchanged">; +def NoAllocate : I32EnumCase<"NoAllocate", 4, "no_allocate">; + +def CacheEvictionPriority : I32Enum<"CacheEvictionPriority", + "NVVM Cache Eviction Priority", + [EvictNormal, EvictFirst, EvictLast, + EvictUnchanged, NoAllocate]> { + let cppNamespace = "::mlir::NVVM"; +} + +def CacheEvictionPriorityAttr : EnumAttr { + let assemblyFormat = "$value"; +} + //===----------------------------------------------------------------------===// // NVVM intrinsic operations //===----------------------------------------------------------------------===// @@ -2333,6 +2353,60 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// NVVM Prefetch Op +//===----------------------------------------------------------------------===// + +def PrefetchCacheLevelL1 : I32EnumCase<"L1", 0, "L1">; +def PrefetchCacheLevelL2 : I32EnumCase<"L2", 1, "L2">; + +def PrefetchCacheLevel : I32Enum<"PrefetchCacheLevel", + "NVVM Prefetch Cache Level", + [PrefetchCacheLevelL1, PrefetchCacheLevelL2]> { + let cppNamespace = "::mlir::NVVM"; +} + +def PrefetchCacheLevelAttr : EnumAttr { + let assemblyFormat = "$value"; +} + +def NVVM_PrefetchOp : NVVM_Op<"prefetch"> { + let summary = "Brings the cache line containing an address into the specified cache level"; + let description = [{ + Operand `addr` can be a global, local or generic address pointer. No + operation is performed if `addr` maps to a `shared` memory location. + + The `cacheLevel` attribute specifies the cache level to which the cache line + containing the specified address is brought. + + `uniform` can be specified after the `cacheLevel` to indicate that the + prefetch is performed to the specified uniform cache level. If `uniform` is + specified, `addr` must be a generic address pointer and no operation is + performed if `addr` maps to a `const`, `local`, or `shared` memory location. + + The `evictPriority` attribute is optional and specifies the cache eviction + priority when `cacheLevel` is L2. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu) + }]; + let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel, + UnitAttr:$uniform, + AnyTypeOf<[LLVM_PointerGlobal, + LLVM_PointerLocal, + LLVM_PointerGeneric]>:$addr, + OptionalAttr:$evictPriority); + let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op); + }]; + let llvmBuilder = [{ + auto intId = NVVM::PrefetchOp::getIntrinsicID(op); + createIntrinsicCall(builder, intId, $addr); + }]; +} + def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap", [DeclareOpInterfaceMethods]>, Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a77ff1e32dc23..aaf6b0593c2e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1205,6 +1205,42 @@ LogicalResult NVVM::VoteSyncOp::verify() { return success(); } +LogicalResult NVVM::PrefetchOp::verify() { + using MemSpace = NVVM::NVVMMemorySpace; + using CacheLevel = NVVM::PrefetchCacheLevel; + + unsigned addressSpace = + llvm::cast(getAddr().getType()).getAddressSpace(); + std::optional evictPriority = getEvictPriority(); + + if (getUniform()) { + if (getCacheLevel() != CacheLevel::L1) + return emitOpError("unsupported cache level, the only supported uniform " + "cache level is L1"); + + if (addressSpace != MemSpace::kGenericMemorySpace) + return emitOpError( + "prefetch to uniform cache requires a generic pointer"); + } + + if (evictPriority) { + if (getCacheLevel() != CacheLevel::L2) + return emitOpError( + "cache eviction priority supported only for cache level L2"); + + if (addressSpace != MemSpace::kGlobalMemorySpace) + return emitOpError("cache eviction priority requires a global pointer"); + + if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && + *evictPriority != NVVM::CacheEvictionPriority::EvictLast) + return emitOpError( + "unsupported cache eviction priority, only evict_last and " + "evict_normal are supported"); + } + + return success(); +} + /// Packs the given `field` into the `result`. /// The `result` is 64-bits and each `field` can be 32-bits or narrower. static llvm::Value * @@ -1734,6 +1770,48 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs( return {ids[type], args}; } +llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { + using MemSpace = NVVM::NVVMMemorySpace; + using CacheLevel = NVVM::PrefetchCacheLevel; + + NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel(); + std::optional evictPriority = + op.getEvictPriority(); + unsigned addressSpace = + llvm::cast(op.getAddr().getType()) + .getAddressSpace(); + + if (op.getUniform() && cacheLevel == CacheLevel::L1) + return llvm::Intrinsic::nvvm_prefetchu_L1; + + if (evictPriority && cacheLevel == CacheLevel::L2) { + switch (*evictPriority) { + case NVVM::CacheEvictionPriority::EvictLast: + return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last; + case NVVM::CacheEvictionPriority::EvictNormal: + return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal; + default: + llvm_unreachable("Invalid cache eviction priority"); + } + } + + switch (addressSpace) { + case MemSpace::kGenericMemorySpace: + return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1 + : llvm::Intrinsic::nvvm_prefetch_L2; + case MemSpace::kGlobalMemorySpace: + return cacheLevel == CacheLevel::L1 + ? llvm::Intrinsic::nvvm_prefetch_global_L1 + : llvm::Intrinsic::nvvm_prefetch_global_L2; + case MemSpace::kLocalMemorySpace: + return cacheLevel == CacheLevel::L1 + ? llvm::Intrinsic::nvvm_prefetch_local_L1 + : llvm::Intrinsic::nvvm_prefetch_local_L2; + default: + llvm_unreachable("Invalid pointer address space"); + } +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index a02d33f50e0d2..c7fa41c98ac92 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -596,6 +596,29 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: return } +// CHECK-LABEL: @prefetch +func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) { + // CHECK: nvvm.prefetch level = L1, %{{.*}} + nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0> + // CHECK: nvvm.prefetch level = L1, %{{.*}} + nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5> + // CHECK: nvvm.prefetch level = L1, %{{.*}} + nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1> + // CHECK: nvvm.prefetch level = L2, %{{.*}} + nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0> + // CHECK: nvvm.prefetch level = L2, %{{.*}} + nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5> + // CHECK: nvvm.prefetch level = L2, %{{.*}} + nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1> + // CHECK: nvvm.prefetch level = L2, %{{.*}} + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1> + // CHECK: nvvm.prefetch level = L2, %{{.*}} + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1> + // CHECK: nvvm.prefetch level = L1 uniform, %{{.*}} + nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr + return +} + // ----- // Just check these don't emit errors. diff --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir new file mode 100644 index 0000000000000..f38b7529a7233 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) { + // CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) { + // CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0) + // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1) + // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0> + nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5> + nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1> + llvm.return +} + +llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) { + // CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) { + // CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0) + // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1) + // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0> + nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5> + nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1> + llvm.return +} + +llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) { + // CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) { + // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0) + // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1> + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1> + llvm.return +} + +llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) { + // CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) { + // CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 3d63434f310bd..8c4f0aafd36a7 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -248,3 +248,67 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 llvm.return } + +// ----- + +llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) { + // expected-error @below {{cache eviction priority supported only for cache level L2}} + nvvm.prefetch level = L1, %global_ptr, evict_priority = evict_last : !llvm.ptr<1> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm.ptr<5>) { + // expected-error @below {{cache eviction priority requires a global pointer}} + nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_last : !llvm.ptr<5> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !llvm.ptr<5>) { + // expected-error @below {{cache eviction priority requires a global pointer}} + nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_normal : !llvm.ptr<5> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>) { + // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}} + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_first : !llvm.ptr<1> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<1>) { + // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}} + nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_unchanged : !llvm.ptr<1> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_L2_with_invalid_no_allocate(%global_ptr: !llvm.ptr<1>) { + // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}} + nvvm.prefetch level = L2, %global_ptr, evict_priority = no_allocate : !llvm.ptr<1> + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_uniform_with_L2(%gen_ptr: !llvm.ptr) { + // expected-error @below {{unsupported cache level, the only supported uniform cache level is L1}} + nvvm.prefetch level = L2 uniform, %gen_ptr : !llvm.ptr + llvm.return +} + +// ----- + +llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<1>) { + // expected-error @below {{prefetch to uniform cache requires a generic pointer}} + nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> + llvm.return +}