@Wolfram70 Wolfram70 commented Feb 12, 2025

Motivation:
Currently, the NVVMOps are not verified against the supported SM architectures. This can manifest as an ISel failure in the NVPTX LLVM backend during CodeGen to PTX ISA. This PR addresses this issue by adding verifier checks for Target-SM architectures in the NVVM Dialect itself, thereby catching the errors early on.

Summary:

  • Parametric traits named NVVMRequiresSM and NVVMRequiresSMa are added to facilitate the version checks for typical and arch-accelerated versions respectively.
  • These traits can be attached to any NVVM Op to enable the checks for the particular Op. (example shown below)
  • An attribute interface named TargetAttrVerifyInterface is added to the GPU dialect; any target attribute that wants to perform target-specific verification on the module can implement it.
  • The checks are performed by NVVMTargetAttr, which implements TargetAttrVerifyInterface. When called from the GPU module verifier, it walks the module and checks every Op that carries one of the NVVMRequiresSM traits.
  • A few Ops in NVVMOps.td have been updated to serve as examples.

Example Usage:

       def NVVM_ReduxOp : NVVM_Op<"redux.sync"> {...} 
 ----> def NVVM_ReduxOp : NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]> {...}

       def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {...}
 ----> def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {...}
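
To make the compatibility rule concrete, here is a minimal standalone C++ sketch. It is modelled on the `isCompatible` helper in the diff as first posted (shown further down), not on the merged code, and it does not depend on MLIR. The names `SMVersion`, the free function `isCompatible`, and `verifyAgainstTargets` are illustrative assumptions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the SM version record used by the check.
struct SMVersion {
  int version;          // e.g. 80 for sm_80
  bool archAccelerated; // true for the "a" variants such as sm_90a
};

// Plain requirements act as minimums. If either side is arch-accelerated,
// the two versions must match exactly (same number, same "a" flag).
bool isCompatible(const SMVersion &required, const SMVersion &target) {
  if (required.archAccelerated || target.archAccelerated)
    return required.version == target.version &&
           required.archAccelerated == target.archAccelerated;
  return required.version <= target.version;
}

// An Op is accepted only if every target attached to the module supports it.
bool verifyAgainstTargets(const SMVersion &required,
                          const std::vector<SMVersion> &targets) {
  for (const SMVersion &target : targets)
    if (!isCompatible(required, target))
      return false;
  return true;
}
```

Under this rule a requirement of sm_80 passes on sm_80, sm_86 and sm_90 and fails on sm_70, while a requirement of sm_90a passes only on sm_90a. These are the same cases exercised by nvvm-check-targetSM.mlir in the diff.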

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds the NVVMRequiresSM op trait to the NVVM dialect to allow tagging NVVM Ops with a minimum required SM version.

Summary:

  • A parametric trait named NVVMRequiresSM is added that facilitates the version checks.

  • An Op can participate in the version check by having this trait attached to its definition. (example shown below)

  • The target SM version for comparison is obtained from the NVVMTargetAttr attached to the parent GPU module.

  • A few Ops in NVVMOps.td have been updated to serve as examples:

    Ex. def NVVM_ReduxOp : NVVM_Op<"redux.sync"> {...}
    --> def NVVM_ReduxOp : NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]> {...}
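
The target SM version comes from the chip string of the NVVMTargetAttr (for example "sm_80" or "sm_90a"). The standalone C++17 sketch below shows one way such a string could be split into a version number and an arch-accelerated flag. It is not the code in this PR: `ParsedChip` and `parseChip` are hypothetical names, and rejecting malformed input with `std::nullopt` is an assumption added for illustration.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string_view>

// Hypothetical result type: numeric SM version plus the "a" suffix flag.
struct ParsedChip {
  int version;
  bool archAccelerated;
};

// Parses chip names of the form "sm_<digits>" with an optional trailing 'a'.
// Returns std::nullopt for anything else.
std::optional<ParsedChip> parseChip(std::string_view chip) {
  constexpr std::string_view prefix = "sm_";
  if (chip.substr(0, prefix.size()) != prefix)
    return std::nullopt;
  chip.remove_prefix(prefix.size());

  int version = 0;
  std::size_t digits = 0;
  while (digits < chip.size() &&
         std::isdigit(static_cast<unsigned char>(chip[digits]))) {
    if (version > 10000) // refuse absurdly long numbers instead of overflowing
      return std::nullopt;
    version = version * 10 + (chip[digits] - '0');
    ++digits;
  }
  if (digits == 0)
    return std::nullopt;

  std::string_view suffix = chip.substr(digits);
  if (suffix.empty())
    return ParsedChip{version, false};
  if (suffix == "a")
    return ParsedChip{version, true};
  return std::nullopt;
}
```

For instance, `parseChip("sm_90a")` yields version 90 with the flag set, while a non-NVVM name such as "gfx90a" is rejected.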


Full diff: https://github.com/llvm/llvm-project/pull/126886.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+12-8)
  • (added) mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h (+91)
  • (added) mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td (+22)
  • (modified) mlir/lib/Dialect/LLVMIR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+31)
  • (added) mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir (+19)
  • (added) mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir (+46)
  • (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+17)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index a9270c6f52344..8702129f1fdef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -17,6 +17,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index fe15a524ec3b5..b2eede7f797dd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
 include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
-  NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic,
+    !listconcat(traits,
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
   let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
   let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
   let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
 def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
 def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
 def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
 def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
 def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
 def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
 
 //===----------------------------------------------------------------------===//
 // Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
 def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
 def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
 def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
 def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
 def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
 def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
 
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
 def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
 def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
 def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
 def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
 
 def NVVM_ReduxOp :
-  NVVM_Op<"redux.sync">,
+  NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$val,
                  ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
+                              [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
new file mode 100644
index 0000000000000..40c2bcd50f235
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -0,0 +1,91 @@
+//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+
+namespace NVVM {
+
+struct NVVMCheckSMVersion {
+  int ArchVersion;
+  bool ArchAccelerated;
+  std::string ArchString;
+
+  NVVMCheckSMVersion() {}
+  NVVMCheckSMVersion(StringRef SMVersion) : ArchString(SMVersion) {
+    parse(SMVersion);
+  }
+  NVVMCheckSMVersion(int ArchVersion, bool ArchAccelerated)
+      : ArchVersion(ArchVersion), ArchAccelerated(ArchAccelerated) {
+    ArchString = (llvm::Twine("sm_") + llvm::Twine(ArchVersion) +
+                  (ArchAccelerated ? "a" : "\0"))
+                     .str();
+  }
+
+  const StringRef getArchString() const { return ArchString; }
+
+  void parse(StringRef SMVersion) {
+    ArchAccelerated = (SMVersion[SMVersion.size() - 1] == 'a');
+    SMVersion.drop_front(3)
+        .take_while([](char c) { return llvm::isDigit(c); })
+        .getAsInteger(10, ArchVersion);
+  }
+
+  bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+    // for arch-conditional SMs, they should exactly match to be valid
+    if (ArchAccelerated || TargetSM.ArchAccelerated)
+      return (*this) == TargetSM;
+
+    return ArchVersion <= TargetSM.ArchVersion;
+  }
+
+  bool operator==(const NVVMCheckSMVersion &Other) const {
+    return ArchVersion == Other.ArchVersion &&
+           ArchAccelerated == Other.ArchAccelerated;
+  }
+};
+
+llvm::SmallVector<NVVMCheckSMVersion> getTargetSMVersions(Operation *op);
+
+LogicalResult
+verifyOpSMRequirements(Operation *op,
+                       llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
+                       NVVMCheckSMVersion RequiredSMVersion);
+} // namespace NVVM
+
+namespace OpTrait {
+
+template <int Version, bool ArchAccelerated = false>
+class NVVMRequiresSM {
+public:
+  template <typename ConcreteOp>
+  class Impl : public OpTrait::TraitBase<
+                   ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      NVVM::NVVMCheckSMVersion RequiredSMVersion(Version, ArchAccelerated);
+      llvm::SmallVector<NVVM::NVVMCheckSMVersion> TargetSMVersions =
+          NVVM::getTargetSMVersions(op);
+
+      return NVVM::verifyOpSMRequirements(op, TargetSMVersions,
+                                          RequiredSMVersion);
+    }
+  };
+};
+} // namespace OpTrait
+} // namespace mlir
+#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
new file mode 100644
index 0000000000000..7b2b43e88dc57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -0,0 +1,22 @@
+//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_TRAITS
+#define NVVM_TRAITS
+
+include "mlir/IR/OpBase.td"
+
+class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+  ParamNativeOpTrait<"NVVMRequiresSM",
+                    !cast<string>(Version) # "," # ArchAccelerated>;
+
+#endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c9a3b97294562..0d14dea3ca168 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRGPUDialect
   MLIRSideEffectInterfaces
   MLIRInferIntRangeInterface
   )
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..9315382727f89 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -18,6 +18,7 @@
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -1439,6 +1440,36 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Requires minimum target SM trait helper functions
+//===----------------------------------------------------------------------===//
+llvm::SmallVector<NVVMCheckSMVersion> NVVM::getTargetSMVersions(Operation *op) {
+  llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions;
+  gpu::GPUModuleOp GPUModule = op->getParentOfType<gpu::GPUModuleOp>();
+  if (GPUModule && GPUModule->hasAttr("targets")) {
+    ArrayAttr Targets = dyn_cast<ArrayAttr>(GPUModule->getAttr("targets"));
+    for (auto Target : Targets) {
+      if (auto NVVMTarget = dyn_cast<NVVMTargetAttr>(Target))
+        TargetSMVersions.push_back(NVVMCheckSMVersion(NVVMTarget.getChip()));
+    }
+  }
+  return TargetSMVersions;
+}
+
+// Helper function to verify the minimum SM requirement of an NVVM Op
+LogicalResult NVVM::verifyOpSMRequirements(
+    Operation *op, llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
+    NVVMCheckSMVersion RequiredSMVersion) {
+  for (auto TargetSMVersion : TargetSMVersions) {
+    if (!RequiredSMVersion.isCompatible(TargetSMVersion)) {
+      op->emitOpError() << "is not supported on "
+                        << TargetSMVersion.getArchString();
+      return failure();
+    }
+  }
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir
new file mode 100644
index 0000000000000..52dabd8c285f1
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --mlir-very-unsafe-disable-verifier-on-parsing  -verify-diagnostics
+
+// Just check these don't emit errors.
+
+gpu.module @check_invalid_disabled_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_invalid_disabled_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_invalid_disabled_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+  test.nvvm_requires_sm_90a
+}
+
+gpu.module @check_invalid_disabled_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+  test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
new file mode 100644
index 0000000000000..bf5c349a9aa7b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Just check these don't emit errors.
+gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+  test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+  test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+  // expected-error @below {{is not supported on sm_70}}
+  test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+  // expected-error @below {{is not supported on sm_75}}
+  test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+  // expected-error @below {{is not supported on sm_90}}
+  test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+  // expected-error @below {{is not supported on sm_80}}
+  test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 618b13da9899f..b1ffbcc7df9a2 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRLinalgDialect
   MLIRLinalgTransforms
   MLIRLLVMDialect
+  MLIRNVVMDialect
   MLIRPass
   MLIRPolynomialDialect
   MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..8cffce27353fd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..4f922e62e2b1a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -13,6 +13,7 @@ include "TestDialect.td"
 include "TestInterfaces.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -2698,6 +2699,22 @@ def TestLinalgFillOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test NVVM RequiresSM trait.
+//===----------------------------------------------------------------------===//
+
+def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
+                                    [NVVMRequiresSM<80>]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+                                          [NVVMRequiresSM<90, "true">]> {
+  let arguments = (ins );
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//


github-actions bot commented Feb 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 067cd8b to 8fb1029 Compare February 12, 2025 11:14
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 8fb1029 to 326a898 Compare February 21, 2025 04:18

joker-eph commented Mar 11, 2025

Can you add the motivation and usage to the description of the PR? Right now it describes what the change does but does not really provide context on why.

> The check can be disabled with the existing --mlir-very-unsafe-disable-verifier-on-parsing flag (tests added in nvvm-check-targetSM-disabled.mlir)

Please remove this (the test and the mention in the PR description): 1) as the name indicates, this isn't something we'd want anyone to use, and 2) this is misleading: it does not disable any check other than on initial parsing.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 326a898 to 74d914d Compare March 11, 2025 10:15
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 74d914d to f989ddc Compare March 14, 2025 09:35
@llvmbot llvmbot added mlir:gpu bazel "Peripheral" support tier build system: utils/bazel labels Mar 14, 2025
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from f989ddc to 0d542ca Compare March 14, 2025 09:39

@joker-eph joker-eph left a comment

This LGTM, with some comments.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 229998e to 7305d3a Compare March 14, 2025 11:58

joker-eph commented Mar 14, 2025

LG, but I'd like to see an approval from either @durga4github or @grypp here. Thanks!

Also please make sure the description reflects the latest changes.

@durga4github

> LG, but I'd like to see an approval from either @durga4github or @grypp here. Thanks!

I am working closely with the author on this ;-)
Hence, we wanted to get reviews from you and Guray!

> Also please make sure the description reflects the latest changes.

@Wolfram70 Wolfram70 self-assigned this Mar 14, 2025
Member

What is the minimum SM that LLVM supports? This could be the default attribute, so we don't have to set sm_20 everywhere.

Contributor Author

It looks like sm_20 is the minimum version for the Ops, so I made it the default in the verifyTarget function and removed it from the Ops. Thanks!
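
The idea of a default minimum can be sketched in a few lines of standalone C++17. This is only an illustration of the fallback described in the reply above, not the PR's verifyTarget code; `kDefaultMinSM` and `isSupported` are hypothetical names, and arch-accelerated targets are deliberately left out.

```cpp
#include <cassert>
#include <optional>

// Assumed baseline: sm_20, the minimum named in the reply above.
constexpr int kDefaultMinSM = 20;

// `declaredMinSM` is empty for Ops that carry no explicit requirement;
// such Ops fall back to the baseline instead of being tagged one by one.
bool isSupported(std::optional<int> declaredMinSM, int targetSM) {
  int required = declaredMinSM.value_or(kDefaultMinSM);
  return required <= targetSM;
}
```

With a fallback like this, only Ops that need something newer than the baseline have to declare a requirement.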

Member

Should we mark at least all the sm_90 Ops, so that we have good test coverage?

Contributor Author

I have marked more of the new Ops for sm_90 and sm_100 in the latest revision, thanks!

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from ac8d54c to a4b7f67 Compare April 15, 2025 09:04
@@ -0,0 +1,38 @@
//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
Member

Can we rename the file to NVVMRequiresSMTraits.td?

Contributor Author

Renamed the files in the latest revision.

!cast<string>(minVersion) # "," # isArchAccelerated # ","
# exactMatch>;

def NVVMRequiresSM90a : NVVMRequiresSM<90, "true", "true">;
Member

Without the 'a' suffix, one can just use this:

NVVMRequiresSM<90> 

But when there is an 'a' suffix, we define a shortcut that is not intuitive to discover:

NVVMRequiresSM90a

Can we have this?

NVVMRequiresSMa<90> 

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, that looks better. Changed it in the latest revision. Thanks!
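
For illustration, a list-taking form like the `NVVMRequiresSMa<[90]>` spelling in the PR description could be modelled in C++17 with a variadic template. The sketch below is an assumption about the shape of such a trait, not the generated or merged code: `RequiresSMa` and `accepts` are hypothetical names, and the rule shown (the target must be arch-accelerated and match one listed version exactly) is inferred from the tests in the diff.

```cpp
#include <cassert>

// Hypothetical model of an arch-accelerated requirement over a version list.
// The target is accepted only if it is an "a" target whose version number
// equals one of the listed versions.
template <int... Versions>
struct RequiresSMa {
  static constexpr bool accepts(int targetVersion,
                                bool targetIsArchAccelerated) {
    // C++17 fold expression: true if any listed version matches.
    return targetIsArchAccelerated && ((targetVersion == Versions) || ...);
  }
};

// Compile-time spot checks of the rule.
static_assert(RequiresSMa<90>::accepts(90, true));
static_assert(!RequiresSMa<90>::accepts(90, false));
```

Taking a list keeps the spelling uniform with `NVVMRequiresSM<90>` while still allowing an Op to name several arch-accelerated targets.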

Wolfram70 and others added 11 commits May 15, 2025 11:22
This change adds the NVVMRequiresSM op trait to the NVVM dialect to
allow tagging NVVM Ops with a minimum required SM version. When a
target SM is able to be determined (through NVVMTargetAttr), this
allows the verification of SM compatibility with the Op without needing
to unnecessarily lower any further down.
Apply suggestions from code review

Co-authored-by: Guray Ozen <[email protected]>
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 8b90730 to 46c8353 Compare May 15, 2025 05:52
@Wolfram70 Wolfram70 changed the title [MLIR][NVVM] Add NVVMRequiresSM op trait [MLIR][NVVM] Add NVVMRequiresSM op traits May 19, 2025
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-NVVMRequiresSM-trait branch from 04545b1 to 140c37d Compare May 19, 2025 10:25

grypp commented May 19, 2025

Thanks for bearing with me and handling all the code reviews—I think the PR is ready to land!


Wolfram70 commented May 20, 2025

No worries. Thanks!

@Wolfram70 Wolfram70 merged commit 9a553d3 into llvm:main May 21, 2025
11 checks passed

Labels

bazel "Peripheral" support tier build system: utils/bazel mlir:gpu mlir:llvm mlir
