diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td index 6d5fd01499121..018821f16c3a2 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td @@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> { ]; } +def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> { + let description = [{ + Interface for GPU target attributes that verify the target attribute + of a given GPU module. + }]; + let cppNamespace = "::mlir::gpu"; + let methods = [ + InterfaceMethod<[{ + Verifies that the target attribute is valid for the given GPU module. + }], "::mlir::LogicalResult", "verifyTarget", + (ins "::mlir::Operation *":$module)> + ]; +} + def GPUTargetAttr : ConfinedAttr]> { let description = [{ diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 68095b7bf5c59..8d83d02e27c33 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [ /// Sets the targets of the module. void setTargets(ArrayRef targets); }]; + + let hasVerifier = 1; } def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 759de745440c2..9c5bbae1022f7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen) add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS NVVMRequiresSMTraits.td) +mlir_tablegen(NVVMRequiresSMTraits.h.inc -gen-op-interface-decls) +mlir_tablegen(NVVMRequiresSMTraits.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen) +add_dependencies(mlir-headers MLIRNVVMRequiresSMTraitsIncGen) + add_mlir_dialect(NVVMOps nvvm) add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm) set(LLVM_TARGET_DEFINITIONS NVVMOps.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index f1eae15d6bf18..bb2da40ae1cbe 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -15,8 +15,10 @@ #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/InferIntRangeInterface.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 654aff71f25be..1a0fd20baa2cd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -16,6 +16,7 @@ include "mlir/IR/EnumAttr.td" include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" @@ -138,8 +139,10 @@ class NVVM_SpecialRegisterOp traits = []> : let assemblyFormat = "attr-dict `:` type($res)"; } -class NVVM_SpecialRangeableRegisterOp : - NVVM_SpecialRegisterOp]> { +class NVVM_SpecialRangeableRegisterOp traits = []> : + NVVM_SpecialRegisterOp])> { let arguments = (ins OptionalAttr:$range); let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; @@ -202,7 +205,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">; //===----------------------------------------------------------------------===// // CTA Cluster index and range -def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">; +def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>; def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">; def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">; def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">; @@ -212,16 +215,16 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster //===----------------------------------------------------------------------===// // CTA index and range within Cluster -def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">; -def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">; -def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">; -def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">; -def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">; +def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>; +def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>; +def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>; def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; //===----------------------------------------------------------------------===// // CTA index and across Cluster dimensions -def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">; +def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>; def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">; //===----------------------------------------------------------------------===// @@ -273,7 +276,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind", def ReduxKindAttr : EnumAttr; def NVVM_ReduxOp : - NVVM_Op<"redux.sync">, + NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>, Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_Type:$val, ReduxKindAttr:$kind, @@ -322,7 +325,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">, } /// mbarrier.init instruction with shared pointer type -def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">, +def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods]>, Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> { string llvmBuilder = [{ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count}); @@ -544,7 +547,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> { let assemblyFormat = "attr-dict"; } -def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> { +def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>]> { let arguments = (ins OptionalAttr:$aligned); let summary = "Cluster Barrier Relaxed Arrive Op"; @@ -570,7 +573,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> { let assemblyFormat = "attr-dict"; } -def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> { +def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> { let arguments = (ins OptionalAttr:$aligned); let summary = "Cluster Barrier Wait Op"; @@ -775,7 +778,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind", def ShflKindAttr : EnumAttr; def NVVM_ShflOp : - NVVM_Op<"shfl.sync">, + NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>, Results<(outs LLVM_Type:$res)>, Arguments<(ins I32:$thread_mask, LLVM_Type:$val, @@ -2114,7 +2117,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">, }]; } -def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">, +def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>]>, Arguments<(ins ConfinedAttr]>:$group, OptionalAttr:$read)> { @@ -2144,7 +2147,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">, def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods, - AttrSizedOperandSegments]>, + AttrSizedOperandSegments, NVVMRequiresSM<90>]>, Arguments<(ins LLVM_PointerShared:$dstMem, LLVM_AnyPointer:$tmaDescriptor, Variadic:$coordinates, @@ -2581,7 +2584,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp : // NVVM Wgmma Ops //===----------------------------------------------------------------------===// -def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> { +def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> { let arguments = (ins); let description = [{ Enforce an ordering of register accesses between warpgroup level matrix @@ -2595,8 +2598,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> { }]; } -def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">, - Arguments<(ins )> { +def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> { let assemblyFormat = "attr-dict"; let description = [{ Commits all prior uncommitted warpgroup level matrix multiplication operations. @@ -2608,7 +2610,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">, }]; } -def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{ +def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> { let arguments = (ins I64Attr:$group); let assemblyFormat = "attr-dict $group"; let description = [{ @@ -2804,7 +2806,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp def NVVM_MapaOp: NVVM_Op<"mapa", [TypesMatchWith<"`res` and `a` should have the same type", - "a", "res", "$_self">]> { + "a", "res", "$_self">, NVVMRequiresSM<90>]> { let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); @@ -2971,7 +2973,7 @@ def Tcgen05WaitKindAttr : let assemblyFormat = "`<` $value `>`"; } -def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> { +def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 alloc operation"; let description = [{ The `tcgen05.alloc` Op allocates tensor core memory for @@ -3001,7 +3003,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> { }]; } -def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> { +def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 dealloc operation"; let description = [{ The `tcgen05.dealloc` Op de-allocates the tensor core memory @@ -3029,7 +3031,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> { }]; } -def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> { +def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 Op to relinquish the right to allocate"; let description = [{ The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA @@ -3052,7 +3054,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm }]; } -def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> { +def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 fence operations"; let description = [{ The `tcgen05.fence` orders all prior async tcgen05 operations @@ -3074,7 +3076,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> { }]; } -def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> { +def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 wait operations"; let description = [{ The `tcgen05.wait` causes the executing thread to block until @@ -3096,7 +3098,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> { }]; } -def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> { +def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 commit operations"; let description = [{ The `tcgen05.commit` makes the mbarrier object, specified by @@ -3134,7 +3136,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> { }]; } -def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> { +def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> { let summary = "Tcgen05 shift operation"; let description = [{ The `tcgen05.shift` is an asynchronous instruction which initiates @@ -3200,7 +3202,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr { +def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> { let summary = "Tcgen05 copy operation"; let description = [{ Instruction tcgen05.cp initiates an asynchronous copy operation from @@ -3270,7 +3272,7 @@ def Tcgen05LdStShapeAttr: EnumAttr { +def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> { let summary = "tensor memory load instructions"; let arguments = (ins // Attributes @@ -3360,7 +3362,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> { // NVVM tcgen05.st Op //===----------------------------------------------------------------------===// -def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> { +def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> { let summary = "tensor memory store instructions"; let arguments = (ins // Attributes @@ -3512,7 +3514,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> { // NVVM target attribute. //===----------------------------------------------------------------------===// -def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> { +def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target", + [DeclareAttrInterfaceMethods]> { let description = [{ GPU target attribute for controlling compilation of NVIDIA targets. All parameters decay into default values if not present. @@ -3539,10 +3542,11 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> { StringRefParameter<"Target chip.", "\"sm_50\"">:$chip, StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features, OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags, - OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link + OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link, + DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget ); let assemblyFormat = [{ - (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)? + (`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget)^ `>`)? }]; let builders = [ AttrBuilder<(ins CArg<"int", "2">:$optLevel, @@ -3550,8 +3554,9 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> { CArg<"StringRef", "\"sm_50\"">:$chip, CArg<"StringRef", "\"+ptx60\"">:$features, CArg<"DictionaryAttr", "nullptr">:$targetFlags, - CArg<"ArrayAttr", "nullptr">:$linkFiles), [{ - return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles); + CArg<"ArrayAttr", "nullptr">:$linkFiles, + CArg<"bool", "true">:$verifyTarget), [{ + return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget); }]> ]; let skipDefaultBuilders = 1; @@ -3562,6 +3567,7 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> { bool hasFtz() const; bool hasCmdOptions() const; std::optional getCmdOptions() const; + LogicalResult verifyTarget(Operation *gpuModule); }]; let extraClassDefinition = [{ bool $cppClass::hasFlag(StringRef flag) const { diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h new file mode 100644 index 0000000000000..36fcaee8ec3a2 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h @@ -0,0 +1,117 @@ +//===--- NVVMRequiresSMTraits.h - NVVM Requires SM Traits -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines op traits for the NVVM Dialect in MLIR +// +//===----------------------------------------------------------------------===// + +#ifndef NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_ +#define NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StorageUniquerSupport.h" +#include "llvm/ADT/StringExtras.h" + +namespace mlir { + +namespace NVVM { + +// Struct to store and check compatibility of SM versions. +struct NVVMCheckSMVersion { + // Set to true if the SM version is accelerated (e.g., sm_90a). + bool archAccelerated; + + // List of SM versions. + // Typically only has one version except for cases where multiple + // arch-accelerated versions are supported. + // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a. + llvm::SmallVector smVersionList; + + template + NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions) + : archAccelerated(archAccelerated), smVersionList({smVersions...}) { + assert((archAccelerated || smVersionList.size() == 1) && + "non arch-accelerated SM version list must be a single version!"); + } + + bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const { + assert(targetSM.smVersionList.size() == 1 && + "target SM version list must be a single version!"); + + if (archAccelerated) { + if (!targetSM.archAccelerated) + return false; + + for (auto version : smVersionList) { + if (version == targetSM.smVersionList[0]) + return true; + } + } else { + return targetSM.smVersionList[0] >= smVersionList[0]; + } + + return false; + } + + bool isMinimumSMVersion() const { return smVersionList[0] >= 20; } + + // Parses an SM version string and returns an equivalent NVVMCheckSMVersion + // object. + static const NVVMCheckSMVersion + getTargetSMVersionFromStr(StringRef smVersionString) { + bool isAA = smVersionString.back() == 'a'; + + int smVersionInt; + smVersionString.drop_front(3) + .take_while([](char c) { return llvm::isDigit(c); }) + .getAsInteger(10, smVersionInt); + + return NVVMCheckSMVersion(isAA, smVersionInt); + } +}; + +} // namespace NVVM +} // namespace mlir + +#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h.inc" + +namespace mlir { + +namespace OpTrait { + +template +class NVVMRequiresSM { +public: + template + class Impl + : public OpTrait::TraitBase::Impl>, + public mlir::NVVM::RequiresSMInterface::Trait { + public: + const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const { + return NVVM::NVVMCheckSMVersion(false, MinVersion); + } + }; +}; + +template +class NVVMRequiresSMa { +public: + template + class Impl : public OpTrait::TraitBase::Impl>, + public mlir::NVVM::RequiresSMInterface::Trait { + public: + const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const { + return NVVM::NVVMCheckSMVersion(true, SMVersions...); + } + }; +}; + +} // namespace OpTrait +} // namespace mlir +#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td new file mode 100644 index 0000000000000..34c0d6b78d5b2 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td @@ -0,0 +1,47 @@ +//===-- NVVMRequiresSMTraits.td - NVVM Requires SM Traits --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines traits for the NVVM Dialect in MLIR +// +//===----------------------------------------------------------------------===// + +#ifndef NVVM_REQUIRES_SM_TRAITS +#define NVVM_REQUIRES_SM_TRAITS + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +// Interface for NVVM Ops with the NVVMRequiresSM parametric trait +def RequiresSMInterface: OpInterface<"RequiresSMInterface"> { + let cppNamespace = "::mlir::NVVM"; + let methods = [ + InterfaceMethod< + "Get the SM version required by the op from the trait", + "const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion" + > + ]; +} + +// OP requires a specified minimum SM value or higher; +// it is not architecture-specific. +class NVVMRequiresSM : + ParamNativeOpTrait<"NVVMRequiresSM", !cast(minVersion)>; + +class StrJoin str_list> { + string ret = !foldl("", str_list, a, b, + !if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b)))); +} + +// OP requires an exact SM match along with +// architecture acceleration. +class NVVMRequiresSMa smVersions> : + ParamNativeOpTrait<"NVVMRequiresSMa", + StrJoin<",", !foreach(vers, smVersions, + !cast(vers))>.ret>; + +#endif //NVVM_REQUIRES_SM_TRAITS diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 84e3071946f59..39f626b558294 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1790,6 +1790,22 @@ void GPUModuleOp::setTargets(ArrayRef targets) { targetsAttr = ArrayAttr::get(getContext(), targetsVector); } +LogicalResult GPUModuleOp::verify() { + auto targets = getOperation()->getAttrOfType("targets"); + + if (!targets) + return success(); + + for (auto target : targets) { + if (auto verifyTargetAttr = + llvm::dyn_cast(target)) { + if (verifyTargetAttr.verifyTarget(getOperation()).failed()) + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // GPUBinaryOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index c9a3b97294562..d83fd3800eb91 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect add_mlir_dialect_library(MLIRNVVMDialect IR/NVVMDialect.cpp IR/BasicPtxBuilderInterface.cpp + IR/NVVMRequiresSMTraits.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR @@ -51,6 +52,7 @@ add_mlir_dialect_library(MLIRNVVMDialect MLIRNVVMOpsIncGen MLIRNVVMConversionsIncGen MLIRBasicPtxBuilderInterfaceIncGen + MLIRNVVMRequiresSMTraitsIncGen intrinsics_gen LINK_COMPONENTS @@ -60,6 +62,7 @@ add_mlir_dialect_library(MLIRNVVMDialect LINK_LIBS PUBLIC MLIRIR MLIRLLVMDialect + MLIRGPUDialect MLIRSideEffectInterfaces MLIRInferIntRangeInterface ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 1ea3f96fa75f5..3523ae5381866 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -1710,7 +1711,7 @@ LogicalResult NVVMTargetAttr::verify(function_ref emitError, int optLevel, StringRef triple, StringRef chip, StringRef features, DictionaryAttr flags, - ArrayAttr files) { + ArrayAttr files, bool verifyTarget) { if (optLevel < 0 || optLevel > 3) { emitError() << "The optimization level must be a number between 0 and 3."; return failure(); @@ -1732,6 +1733,37 @@ NVVMTargetAttr::verify(function_ref emitError, return success(); } +LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) { + if (!getVerifyTarget()) + return success(); + + auto gpuModuleOp = llvm::dyn_cast(gpuModule); + if (!gpuModuleOp) { + return emitError(gpuModule->getLoc(), + "NVVM target attribute must be attached to a GPU module"); + } + + const NVVMCheckSMVersion targetSMVersion = + NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip()); + if (!targetSMVersion.isMinimumSMVersion()) { + return emitError(gpuModule->getLoc(), + "Minimum NVVM target SM version is sm_20"); + } + + gpuModuleOp->walk([&](Operation *op) { + if (auto reqOp = llvm::dyn_cast(op)) { + const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion(); + if (!requirement.isCompatibleWith(targetSMVersion)) { + op->emitOpError() << "is not supported on " << getChip(); + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMRequiresSMTraits.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMRequiresSMTraits.cpp new file mode 100644 index 0000000000000..68bb3c6827bb3 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMRequiresSMTraits.cpp @@ -0,0 +1,15 @@ +//===--- NVVMRequiresSMTraits.cpp - NVVM Requires SM Traits -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines op traits for the NVVM Dialect in MLIR +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h" + +#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.cpp.inc" diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir new file mode 100644 index 0000000000000..e469d336dc1ae --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// Just check these don't emit errors. +gpu.module @check_valid_SM_exact [#nvvm.target] { + test.nvvm_requires_sm_80 +} + +gpu.module @check_valid_SM_greater_1 [#nvvm.target] { + test.nvvm_requires_sm_80 +} + +gpu.module @check_valid_SM_greater_2 [#nvvm.target] { + test.nvvm_requires_sm_80 +} + +gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target] { + test.nvvm_requires_sm_90a +} + +gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target] { + test.nvvm_requires_sm_80 +} + +gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target] { + test.nvvm_requires_sm_90a_or_sm_100a +} + +gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target] { + test.nvvm_requires_sm_90a_or_sm_100a +} + + +gpu.module @disable_verify_target1 [#nvvm.target] { + test.nvvm_requires_sm_90a +} + +gpu.module @disable_verify_target2 [#nvvm.target] { + test.nvvm_requires_sm_80 +} + +gpu.module @disable_verify_target3 [#nvvm.target] { + test.nvvm_requires_sm_90a_or_sm_100a +} + +// ----- + +gpu.module @check_invalid_SM_lesser_1 [#nvvm.target] { + // expected-error @below {{is not supported on sm_70}} + test.nvvm_requires_sm_80 +} + +// ----- + +gpu.module @check_invalid_SM_lesser_2 [#nvvm.target] { + // expected-error @below {{is not supported on sm_75}} + test.nvvm_requires_sm_80 +} + +// ----- + +gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target] { + // expected-error @below {{is not supported on sm_90}} + test.nvvm_requires_sm_90a +} + +// ----- + +gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target] { + // expected-error @below {{is not supported on sm_80}} + test.nvvm_requires_sm_90a +} + +// ----- + +gpu.module @check_invalid_SM_arch_acc_multi_1 [#nvvm.target] { + // expected-error @below {{is not supported on sm_80}} + test.nvvm_requires_sm_90a_or_sm_100a +} + +// ----- + +gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target] { + // expected-error @below {{is not supported on sm_90}} + test.nvvm_requires_sm_90a_or_sm_100a +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index d2181cea0ecf9..f099d01abd31a 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -86,6 +86,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC MLIRLinalgTransforms MLIRPtrDialect MLIRLLVMDialect + MLIRNVVMDialect MLIRPass MLIRReduce MLIRTensorDialect diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index f070c3bedd92c..9d5b225b219cd 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Traits.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 43a0bdaf86cf3..cdc0f393b4761 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -13,6 +13,7 @@ include "TestDialect.td" include "TestInterfaces.td" include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td" include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/OpBase.td" @@ -2800,6 +2801,28 @@ def TestLinalgFillOp : }]; } +//===----------------------------------------------------------------------===// +// Test NVVM RequiresSM trait. +//===----------------------------------------------------------------------===// + +def TestNVVMRequiresSMOp : + TEST_Op<"nvvm_requires_sm_80", [NVVMRequiresSM<80>]> { + let arguments = (ins ); + let assemblyFormat = "attr-dict"; +} + +def TestNVVMRequiresSMArchCondOp : + TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> { + let arguments = (ins ); + let assemblyFormat = "attr-dict"; +} + +def TestNVVMRequirestSMArchCondMultiOp : + TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> { + let arguments = (ins ); + let assemblyFormat = "attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Ops with Default-Valued String Attributes //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 5f7ed5724e3f2..eb33acf4b0e0f 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5774,15 +5774,18 @@ cc_library( name = "NVVMDialect", srcs = [ "lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp", + "lib/Dialect/LLVMIR/IR/NVVMRequiresSMTraits.cpp", "lib/Dialect/LLVMIR/IR/NVVMDialect.cpp", ], hdrs = [ "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h", + "include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h", "include/mlir/Dialect/LLVMIR/NVVMDialect.h", - ], +], includes = ["include"], deps = [ ":BasicPtxBuilderIntGen", + ":NVVMRequiresSMTraitsIntGen", ":BytecodeOpInterface", ":ConvertToLLVMInterface", ":DialectUtils", @@ -5855,12 +5858,20 @@ td_library( ], ) +td_library( + name = "NVVMRequiresSMTraitsIntTdFiles", + srcs = ["include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"] +) + td_library( name = "NVVMOpsTdFiles", srcs = ["include/mlir/Dialect/LLVMIR/NVVMOps.td"], includes = ["include"], deps = [ ":BasicPtxBuilderIntTdFiles", + ":NVVMRequiresSMTraitsIntTdFiles", ":GPUOpsTdFiles", ":LLVMOpsTdFiles", ":OpBaseTdFiles", @@ -5887,6 +5898,23 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "NVVMRequiresSMTraitsIntGen", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td", + deps = [":NVVMRequiresSMTraitsIntTdFiles"], +) + gentbl_cc_library( name = "NVVMOpsIncGen", tbl_outs = {