Skip to content
Merged
26 changes: 13 additions & 13 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2584,7 +2584,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//

def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<90>]> {
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
Expand All @@ -2598,7 +2598,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<9
}];
}

def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<90>]> {
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
Expand All @@ -2610,7 +2610,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [N
}];
}

def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<90>]> {
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
Expand Down Expand Up @@ -2973,7 +2973,7 @@ def Tcgen05WaitKindAttr :
let assemblyFormat = "`<` $value `>`";
}

def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 alloc operation";
let description = [{
The `tcgen05.alloc` Op allocates tensor core memory for
Expand Down Expand Up @@ -3003,7 +3003,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true",
}];
}

def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 dealloc operation";
let description = [{
The `tcgen05.dealloc` Op de-allocates the tensor core memory
Expand Down Expand Up @@ -3031,7 +3031,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "tru
}];
}

def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 Op to relinquish the right to allocate";
let description = [{
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
Expand All @@ -3054,7 +3054,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];
}

def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 fence operations";
let description = [{
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
Expand All @@ -3076,7 +3076,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true",
}];
}

def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 wait operations";
let description = [{
The `tcgen05.wait<load>` causes the executing thread to block until
Expand All @@ -3098,7 +3098,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "f
}];
}

def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 commit operations";
let description = [{
The `tcgen05.commit` makes the mbarrier object, specified by
Expand Down Expand Up @@ -3136,7 +3136,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true"
}];
}

def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
Expand Down Expand Up @@ -3202,7 +3202,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
let assemblyFormat = "`<` $value `>`";
}

def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
Expand Down Expand Up @@ -3272,7 +3272,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
Expand Down Expand Up @@ -3362,7 +3362,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSM<100, "true", "false">]> {
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
Expand Down
101 changes: 61 additions & 40 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,60 @@ namespace mlir {

namespace NVVM {

// Structure to store and check compatibility of SM versions.
// Struct to store and check compatibility of SM versions.
struct NVVMCheckSMVersion {
int archVersion;

// Set to true if the SM version is accelerated (e.g., sm_90a).
bool archAccelerated;

// Set to true if the target SM version must match exactly
// (both archVersion and archAccelerated).
// For example, sm_90a with exactMatch = false will also match
// sm_100a, sm_120a, etc.
bool exactMatch;

NVVMCheckSMVersion()
: archVersion(0), archAccelerated(false), exactMatch(false) {}
NVVMCheckSMVersion(StringRef smVersion, bool exactMatch = false)
: exactMatch(exactMatch) {
parse(smVersion);
}
NVVMCheckSMVersion(int archVersion, bool archAccelerated, bool exactMatch)
: archVersion(archVersion), archAccelerated(archAccelerated),
exactMatch(exactMatch) {}

// Parses the SM version string and sets the archVersion (as an integer)
// and the archAccelerated flag.
void parse(StringRef smVersion) {
archAccelerated = (smVersion.back() == 'a');
smVersion.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, archVersion);
// List of SM versions.
// Typically only has one version except for cases where multiple
// arch-accelerated versions are supported.
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
llvm::SmallVector<int, 1> smVersionList;

template <typename... Ints>
NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
: archAccelerated(archAccelerated), smVersionList({smVersions...}) {
assert((archAccelerated || smVersionList.size() == 1) &&
"non arch-accelerated SM version list must be a single version!");
}

bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
if (exactMatch)
return (*this) == targetSM;
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
assert(targetSM.smVersionList.size() == 1 &&
"target SM version list must be a single version!");

if (archAccelerated) {
if (!targetSM.archAccelerated)
return false;

for (auto version : smVersionList) {
if (version == targetSM.smVersionList[0])
return true;
}
} else {
return targetSM.smVersionList[0] >= smVersionList[0];
}

return archAccelerated
? archVersion <= targetSM.archVersion && targetSM.archAccelerated
: archVersion <= targetSM.archVersion;
return false;
}

bool operator==(const NVVMCheckSMVersion &other) const {
return archVersion == other.archVersion &&
archAccelerated == other.archAccelerated;
bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }

// Parses an SM version string and returns an equivalent NVVMCheckSMVersion
// object.
static const NVVMCheckSMVersion
getTargetSMVersionFromStr(StringRef smVersionString) {
bool isAA = smVersionString.back() == 'a';

int smVersionInt;
smVersionString.drop_front(3)
.take_while([](char c) { return llvm::isDigit(c); })
.getAsInteger(10, smVersionInt);

return NVVMCheckSMVersion(isAA, smVersionInt);
}
};

} // namespace NVVM
} // namespace mlir

Expand All @@ -76,21 +84,34 @@ namespace mlir {

namespace OpTrait {

template <int MinVersion, bool ArchAccelerated = false, bool ExactMatch = false>
template <int MinVersion>
class NVVMRequiresSM {
public:
template <typename ConcreteOp>
class Impl
: public OpTrait::TraitBase<
ConcreteOp,
NVVMRequiresSM<MinVersion, ArchAccelerated, ExactMatch>::Impl>,
: public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
return NVVM::NVVMCheckSMVersion(MinVersion, ArchAccelerated, ExactMatch);
return NVVM::NVVMCheckSMVersion(false, MinVersion);
}
};
};

template <int... SMVersions>
class NVVMRequiresSMa {
public:
template <typename ConcreteOp>
class Impl : public OpTrait::TraitBase<ConcreteOp,
NVVMRequiresSMa<SMVersions...>::Impl>,
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
public:
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
return NVVM::NVVMCheckSMVersion(true, SMVersions...);
}
};
};

} // namespace OpTrait
} // namespace mlir
#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
19 changes: 12 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
];
}

class NVVMRequiresSM<int minVersion, string isArchAccelerated = "false",
string exactMatch = "false"> :
ParamNativeOpTrait<"NVVMRequiresSM",
!cast<string>(minVersion) # "," # isArchAccelerated # ","
# exactMatch>;

class NVVMRequiresSMa<int version> : NVVMRequiresSM<version, "true", "true">;
class NVVMRequiresSM<int minVersion> :
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;

class StrJoin<string sep, list<string> str_list> {
string ret = !foldl("", str_list, a, b,
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
}

class NVVMRequiresSMa<list<int> smVersions> :
ParamNativeOpTrait<"NVVMRequiresSMa",
StrJoin<",", !foreach(vers, smVersions,
!cast<string>(vers))>.ret>;

#endif // NVVM_REQUIRES_SM_TRAITS
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1743,16 +1743,17 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
"NVVM target attribute must be attached to a GPU module");
}

NVVMCheckSMVersion targetSMVersion(getChip());
if (targetSMVersion.archVersion < 20) {
const NVVMCheckSMVersion targetSMVersion =
NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
if (!targetSMVersion.isMinimumSMVersion()) {
return emitError(gpuModule->getLoc(),
"Minimum NVVM target SM version is sm_20");
}

gpuModuleOp->walk([&](Operation *op) {
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
if (!requirement.isCompatible(targetSMVersion)) {
const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
if (!requirement.isCompatibleWith(targetSMVersion)) {
op->emitOpError() << "is not supported on " << getChip();
return WalkResult::interrupt();
}
Expand Down
28 changes: 16 additions & 12 deletions mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@ gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
test.nvvm_requires_sm_80
}

gpu.module @check_valid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90a">] {
gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a
}

gpu.module @check_valid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_atleast_90_aa
gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_80
}

gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
test.nvvm_requires_sm_90a_or_sm_100a
}

gpu.module @check_valid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_100a">] {
test.nvvm_requires_sm_atleast_90_aa
gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
test.nvvm_requires_sm_90a_or_sm_100a
}


Expand All @@ -35,7 +39,7 @@ gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget =
}

gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
test.nvvm_requires_sm_atleast_90_aa
test.nvvm_requires_sm_90a_or_sm_100a
}

// -----
Expand All @@ -54,28 +58,28 @@ gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {

// -----

gpu.module @check_invalid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90">] {
gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
// expected-error @below {{is not supported on sm_90}}
test.nvvm_requires_sm_90a
}

// -----

gpu.module @check_invalid_SM_arch_acc_exact_2 [#nvvm.target<chip = "sm_80">] {
gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
// expected-error @below {{is not supported on sm_80}}
test.nvvm_requires_sm_90a
}

// -----

gpu.module @check_invalid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_80">] {
gpu.module @check_invalid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_80">] {
// expected-error @below {{is not supported on sm_80}}
test.nvvm_requires_sm_atleast_90_aa
test.nvvm_requires_sm_90a_or_sm_100a
}

// -----

gpu.module @check_invalid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_90">] {
gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_90">] {
// expected-error @below {{is not supported on sm_90}}
test.nvvm_requires_sm_atleast_90_aa
test.nvvm_requires_sm_90a_or_sm_100a
}
12 changes: 6 additions & 6 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2805,20 +2805,20 @@ def TestLinalgFillOp :
// Test NVVM RequiresSM trait.
//===----------------------------------------------------------------------===//

def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
[NVVMRequiresSM<80>]> {
def TestNVVMRequiresSMOp :
TEST_Op<"nvvm_requires_sm_80", [NVVMRequiresSM<80>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}

def TestNVVMRequiresAtleastSMArchCondOp :
TEST_Op<"nvvm_requires_sm_atleast_90_aa", [NVVMRequiresSM<90, "true">]> {
def TestNVVMRequiresSMArchCondOp :
TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}

def TestNVVMRequiresExactSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
[NVVMRequiresSMa<90>]> {
def TestNVVMRequiresSMArchCondMultiOp :
TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
}
Expand Down