1616include "mlir/IR/EnumAttr.td"
1717include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+ include "mlir/Dialect/LLVMIR/NVVMTraits.td"
1920include "mlir/Interfaces/SideEffectInterfaces.td"
2021include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2122include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
136137 let assemblyFormat = "attr-dict `:` type($res)";
137138}
138139
139- class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
140- NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
140+ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
141+ NVVM_SpecialRegisterOp<mnemonic,
142+ !listconcat(traits,
143+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
141144 let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
142145 let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
143146 let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
167170def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
168171def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
169172def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
170- def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
173+ def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>] >;
171174def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
172175def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
173176def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
174177
175178//===----------------------------------------------------------------------===//
176179// Lane Mask Comparison Ops
177- def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
180+ def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>] >;
178181def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
179182def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
180183def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
200203
201204//===----------------------------------------------------------------------===//
202205// CTA Cluster index and range
203- def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
206+ def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>] >;
204207def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
205208def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
206209def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
210213
211214//===----------------------------------------------------------------------===//
212215// CTA index and range within Cluster
213- def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
216+ def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>] >;
214217def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
215218def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
216219def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
269272def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
270273
271274def NVVM_ReduxOp :
272- NVVM_Op<"redux.sync">,
275+ NVVM_Op<"redux.sync", [NVVMRequiresSM<80>] >,
273276 Results<(outs LLVM_Type:$res)>,
274277 Arguments<(ins LLVM_Type:$val,
275278 ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
23272330// NVVM Wgmma Ops
23282331//===----------------------------------------------------------------------===//
23292332
2330- def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2333+ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
2334+ [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
23312335 let arguments = (ins);
23322336 let description = [{
23332337 Enforce an ordering of register accesses between warpgroup level matrix
@@ -2341,8 +2345,8 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
23412345 }];
23422346}
23432347
2344- def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2345- Arguments<(ins ) > {
2348+ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
2349+ [NVVMRequiresSM<90, /*ArchAccelerated*/"true">] > {
23462350 let assemblyFormat = "attr-dict";
23472351 let description = [{
23482352 Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2814,7 +2818,8 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
28142818// NVVM target attribute.
28152819//===----------------------------------------------------------------------===//
28162820
2817- def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
2821+ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
2822+ [DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
28182823 let description = [{
28192824 GPU target attribute for controlling compilation of NVIDIA targets. All
28202825 parameters decay into default values if not present.
@@ -2862,6 +2867,9 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
28622867 bool hasFlag(StringRef flag) const;
28632868 bool hasFastMath() const;
28642869 bool hasFtz() const;
2870+ bool hasCmdOptions() const;
2871+ std::optional<mlir::NamedAttribute> getCmdOptions() const;
2872+ LogicalResult verifyTarget(Operation *gpuModule);
28652873 }];
28662874 let extraClassDefinition = [{
28672875 bool $cppClass::hasFlag(StringRef flag) const {
0 commit comments