-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][NVVM] Add NVVMRequiresSM op traits #126886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add NVVMRequiresSM op traits #126886
Conversation
|
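@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Srinivasa Ravi (Wolfram70) Changes: This change adds the `NVVMRequiresSM` op trait to the NVVM dialect to allow tagging NVVM Ops with a minimum required SM version. Summary: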
Full diff: https://github.com/llvm/llvm-project/pull/126886.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index a9270c6f52344..8702129f1fdef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -17,6 +17,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index fe15a524ec3b5..b2eede7f797dd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}
-class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
- NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+ NVVM_SpecialRegisterOp<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
-def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
-def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
//===----------------------------------------------------------------------===//
// CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
def NVVM_ReduxOp :
- NVVM_Op<"redux.sync">,
+ NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
+ [NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
new file mode 100644
index 0000000000000..40c2bcd50f235
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.h
@@ -0,0 +1,91 @@
+//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+
+namespace NVVM {
+
+struct NVVMCheckSMVersion {
+ int ArchVersion;
+ bool ArchAccelerated;
+ std::string ArchString;
+
+ NVVMCheckSMVersion() {}
+ NVVMCheckSMVersion(StringRef SMVersion) : ArchString(SMVersion) {
+ parse(SMVersion);
+ }
+ NVVMCheckSMVersion(int ArchVersion, bool ArchAccelerated)
+ : ArchVersion(ArchVersion), ArchAccelerated(ArchAccelerated) {
+ ArchString = (llvm::Twine("sm_") + llvm::Twine(ArchVersion) +
+ (ArchAccelerated ? "a" : "\0"))
+ .str();
+ }
+
+ const StringRef getArchString() const { return ArchString; }
+
+ void parse(StringRef SMVersion) {
+ ArchAccelerated = (SMVersion[SMVersion.size() - 1] == 'a');
+ SMVersion.drop_front(3)
+ .take_while([](char c) { return llvm::isDigit(c); })
+ .getAsInteger(10, ArchVersion);
+ }
+
+ bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
+ // for arch-conditional SMs, they should exactly match to be valid
+ if (ArchAccelerated || TargetSM.ArchAccelerated)
+ return (*this) == TargetSM;
+
+ return ArchVersion <= TargetSM.ArchVersion;
+ }
+
+ bool operator==(const NVVMCheckSMVersion &Other) const {
+ return ArchVersion == Other.ArchVersion &&
+ ArchAccelerated == Other.ArchAccelerated;
+ }
+};
+
+llvm::SmallVector<NVVMCheckSMVersion> getTargetSMVersions(Operation *op);
+
+LogicalResult
+verifyOpSMRequirements(Operation *op,
+ llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
+ NVVMCheckSMVersion RequiredSMVersion);
+} // namespace NVVM
+
+namespace OpTrait {
+
+template <int Version, bool ArchAccelerated = false>
+class NVVMRequiresSM {
+public:
+ template <typename ConcreteOp>
+ class Impl : public OpTrait::TraitBase<
+ ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) {
+ NVVM::NVVMCheckSMVersion RequiredSMVersion(Version, ArchAccelerated);
+ llvm::SmallVector<NVVM::NVVMCheckSMVersion> TargetSMVersions =
+ NVVM::getTargetSMVersions(op);
+
+ return NVVM::verifyOpSMRequirements(op, TargetSMVersions,
+ RequiredSMVersion);
+ }
+ };
+};
+} // namespace OpTrait
+} // namespace mlir
+#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
new file mode 100644
index 0000000000000..7b2b43e88dc57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMTraits.td
@@ -0,0 +1,22 @@
+//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines traits for the NVVM Dialect in MLIR
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_TRAITS
+#define NVVM_TRAITS
+
+include "mlir/IR/OpBase.td"
+
+class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
+ ParamNativeOpTrait<"NVVMRequiresSM",
+ !cast<string>(Version) # "," # ArchAccelerated>;
+
+#endif //NVVM_TRAITS
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c9a3b97294562..0d14dea3ca168 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
+ MLIRGPUDialect
MLIRSideEffectInterfaces
MLIRInferIntRangeInterface
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..9315382727f89 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -1439,6 +1440,36 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// Requires minimum target SM trait helper functions
+//===----------------------------------------------------------------------===//
+llvm::SmallVector<NVVMCheckSMVersion> NVVM::getTargetSMVersions(Operation *op) {
+ llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions;
+ gpu::GPUModuleOp GPUModule = op->getParentOfType<gpu::GPUModuleOp>();
+ if (GPUModule && GPUModule->hasAttr("targets")) {
+ ArrayAttr Targets = dyn_cast<ArrayAttr>(GPUModule->getAttr("targets"));
+ for (auto Target : Targets) {
+ if (auto NVVMTarget = dyn_cast<NVVMTargetAttr>(Target))
+ TargetSMVersions.push_back(NVVMCheckSMVersion(NVVMTarget.getChip()));
+ }
+ }
+ return TargetSMVersions;
+}
+
+// Helper function to verify the minimum SM requirement of an NVVM Op
+LogicalResult NVVM::verifyOpSMRequirements(
+ Operation *op, llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
+ NVVMCheckSMVersion RequiredSMVersion) {
+ for (auto TargetSMVersion : TargetSMVersions) {
+ if (!RequiredSMVersion.isCompatible(TargetSMVersion)) {
+ op->emitOpError() << "is not supported on "
+ << TargetSMVersion.getArchString();
+ return failure();
+ }
+ }
+ return success();
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir
new file mode 100644
index 0000000000000..52dabd8c285f1
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM-disabled.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --mlir-very-unsafe-disable-verifier-on-parsing -verify-diagnostics
+
+// Just check these don't emit errors.
+
+gpu.module @check_invalid_disabled_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_invalid_disabled_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_invalid_disabled_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+ test.nvvm_requires_sm_90a
+}
+
+gpu.module @check_invalid_disabled_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+ test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
new file mode 100644
index 0000000000000..bf5c349a9aa7b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Just check these don't emit errors.
+gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
+ test.nvvm_requires_sm_80
+}
+
+gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
+ test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
+ // expected-error @below {{is not supported on sm_70}}
+ test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
+ // expected-error @below {{is not supported on sm_75}}
+ test.nvvm_requires_sm_80
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
+ // expected-error @below {{is not supported on sm_90}}
+ test.nvvm_requires_sm_90a
+}
+
+// -----
+
+gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
+ // expected-error @below {{is not supported on sm_80}}
+ test.nvvm_requires_sm_90a
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 618b13da9899f..b1ffbcc7df9a2 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRLinalgDialect
MLIRLinalgTransforms
MLIRLLVMDialect
+ MLIRNVVMDialect
MLIRPass
MLIRPolynomialDialect
MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..8cffce27353fd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2aa0658ab0e5d..4f922e62e2b1a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -13,6 +13,7 @@ include "TestDialect.td"
include "TestInterfaces.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Dialect/LLVMIR/NVVMTraits.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -2698,6 +2699,22 @@ def TestLinalgFillOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Test NVVM RequiresSM trait.
+//===----------------------------------------------------------------------===//
+
+def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
+ [NVVMRequiresSM<80>]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
+def TestNVVMRequiresSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
+ [NVVMRequiresSM<90, "true">]> {
+ let arguments = (ins );
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued String Attributes
//===----------------------------------------------------------------------===//
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
067cd8b to
8fb1029
Compare
8fb1029 to
326a898
Compare
|
Can you add the motivation and usage to the description of the PR? Right now it describes what the change does but does not really provide context on why.
Please remove this (the test and the mention in the PR description): 1) as the name indicates, this isn't something we'd want anyone to use, and 2) this is misleading: it does not disable any check other than on initial parsing. |
326a898 to
74d914d
Compare
74d914d to
f989ddc
Compare
f989ddc to
0d542ca
Compare
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM, with some comments.
229998e to
7305d3a
Compare
|
LG, but I'd like to see an approval from either @durga4github or @grypp here. Thanks! Also please make sure the description reflects the latest changes. |
I am working closely with the author on this ;-)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the minimum SM that llvm supports? This could be the default attribute, so we don't have to set sm20 everywhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like sm_20 is the minimum version for the Ops, so I made that the default in the verifyTarget function and removed it from the Ops. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we mark at least all the sm_90 Ops? That way we would at least have good test coverage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have marked more of the new Ops for sm_90 and sm_100 in the latest revision, thanks!
ac8d54c to
a4b7f67
Compare
| @@ -0,0 +1,38 @@ | |||
| //===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===// | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rename the file to NVVMRequiresSMTraits.td?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed the files in the latest revision.
| !cast<string>(minVersion) # "," # isArchAccelerated # "," | ||
| # exactMatch>; | ||
|
|
||
| def NVVMRequiresSM90a : NVVMRequiresSM<90, "true", "true">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without the `a` suffix, one can just use this:
NVVMRequiresSM<90>
but when there is an `a` suffix, we define a shortcut that is not intuitive to discover:
NVVMRequiresSM90a
Can we have this instead?
NVVMRequiresSMa<90>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yes, that looks better. Changed it in the latest revision. Thanks!
This change adds the NVVMRequiresSM op trait to the NVVM dialect to allow tagging NVVM Ops with a minimum required SM version. When a target SM is able to be determined (through NVVMTargetAttr), this allows the verification of SM compatibility with the Op without needing to unnecessarily lower any further down.
This change adds the NVVMRequiresSM op trait to the NVVM dialect to allow tagging NVVM Ops with a minimum required SM version. When a target SM is able to be determined (through NVVMTargetAttr), this allows the verification of SM compatibility with the Op without needing to unnecessarily lower any further down.
Apply suggestions from code review Co-authored-by: Guray Ozen <[email protected]>
Co-authored-by: Guray Ozen <[email protected]>
8b90730 to
46c8353
Compare
Co-authored-by: Guray Ozen <[email protected]>
04545b1 to
140c37d
Compare
|
Thanks for bearing with me and handling all the code reviews—I think the PR is ready to land! |
|
No worries. Thanks! |
Motivation:
Currently, the NVVMOps are not verified against the supported SM architectures. This can manifest as an ISel failure in the NVPTX LLVM backend during CodeGen to PTX ISA. This PR addresses this issue by adding verifier checks for Target-SM architectures in the NVVM Dialect itself, thereby catching the errors early on.
Summary:
- The op traits `NVVMRequiresSM` and `NVVMRequiresSMa` are added to facilitate the version checks for typical and arch-accelerated versions respectively.
- `TargetAttrVerifyInterface` is added to the GPU dialect, which any target attribute seeking to perform target verification on the module can implement.
- The checks are performed by `NVVMTargetAttr` (implementing the `TargetAttrVerifyInterface` interface) when called from the GPU module verifier, where it walks through the module and performs the checks for Ops with the `NVVMRequiresSM` traits.
- Some Ops in `NVVMOps.td` have been updated to serve as examples.

Example Usage: