Skip to content

Commit 574c895

Browse files
committed
add NVVMRequiresSMa for arch-accelerated SM versions
1 parent 46c8353 commit 574c895

File tree

6 files changed

+115
-83
lines changed

6 files changed

+115
-83
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2584,7 +2584,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
25842584
// NVVM Wgmma Ops
25852585
//===----------------------------------------------------------------------===//
25862586

2587-
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<90>]> {
2587+
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
25882588
let arguments = (ins);
25892589
let description = [{
25902590
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2598,7 +2598,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<9
25982598
}];
25992599
}
26002600

2601-
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<90>]> {
2601+
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
26022602
let assemblyFormat = "attr-dict";
26032603
let description = [{
26042604
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2610,7 +2610,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [N
26102610
}];
26112611
}
26122612

2613-
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<90>]> {
2613+
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
26142614
let arguments = (ins I64Attr:$group);
26152615
let assemblyFormat = "attr-dict $group";
26162616
let description = [{
@@ -2973,7 +2973,7 @@ def Tcgen05WaitKindAttr :
29732973
let assemblyFormat = "`<` $value `>`";
29742974
}
29752975

2976-
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true", "false">]> {
2976+
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
29772977
let summary = "Tcgen05 alloc operation";
29782978
let description = [{
29792979
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -3003,7 +3003,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSM<100, "true",
30033003
}];
30043004
}
30053005

3006-
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "true", "false">]> {
3006+
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
30073007
let summary = "Tcgen05 dealloc operation";
30083008
let description = [{
30093009
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -3031,7 +3031,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSM<100, "tru
30313031
}];
30323032
}
30333033

3034-
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSM<100, "true", "false">]> {
3034+
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
30353035
let summary = "Tcgen05 Op to relinquish the right to allocate";
30363036
let description = [{
30373037
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -3054,7 +3054,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
30543054
}];
30553055
}
30563056

3057-
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true", "false">]> {
3057+
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
30583058
let summary = "Tcgen05 fence operations";
30593059
let description = [{
30603060
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -3076,7 +3076,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSM<100, "true",
30763076
}];
30773077
}
30783078

3079-
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "false">]> {
3079+
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
30803080
let summary = "Tcgen05 wait operations";
30813081
let description = [{
30823082
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -3098,7 +3098,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSM<100, "true", "f
30983098
}];
30993099
}
31003100

3101-
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true", "false">]> {
3101+
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
31023102
let summary = "Tcgen05 commit operations";
31033103
let description = [{
31043104
The `tcgen05.commit` makes the mbarrier object, specified by
@@ -3136,7 +3136,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSM<100, "true"
31363136
}];
31373137
}
31383138

3139-
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSM<100, "true", "false">]> {
3139+
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
31403140
let summary = "Tcgen05 shift operation";
31413141
let description = [{
31423142
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -3202,7 +3202,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
32023202
let assemblyFormat = "`<` $value `>`";
32033203
}
32043204

3205-
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSM<100, "true", "false">]> {
3205+
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
32063206
let summary = "Tcgen05 copy operation";
32073207
let description = [{
32083208
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -3272,7 +3272,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
32723272
// NVVM tcgen05.ld Op
32733273
//===----------------------------------------------------------------------===//
32743274

3275-
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false">]> {
3275+
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
32763276
let summary = "tensor memory load instructions";
32773277
let arguments = (ins
32783278
// Attributes
@@ -3362,7 +3362,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSM<100, "true", "false
33623362
// NVVM tcgen05.st Op
33633363
//===----------------------------------------------------------------------===//
33643364

3365-
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSM<100, "true", "false">]> {
3365+
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
33663366
let summary = "tensor memory store instructions";
33673367
let arguments = (ins
33683368
// Attributes

mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,52 +21,60 @@ namespace mlir {
2121

2222
namespace NVVM {
2323

24-
// Structure to store and check compatibility of SM versions.
24+
// Struct to store and check compatibility of SM versions.
2525
struct NVVMCheckSMVersion {
26-
int archVersion;
27-
2826
// Set to true if the SM version is accelerated (e.g., sm_90a).
2927
bool archAccelerated;
3028

31-
// Set to true if the target SM version must match exactly
32-
// (both archVersion and archAccelerated).
33-
// For example, sm_90a with exactMatch = false will also match
34-
// sm_100a, sm_120a, etc.
35-
bool exactMatch;
36-
37-
NVVMCheckSMVersion()
38-
: archVersion(0), archAccelerated(false), exactMatch(false) {}
39-
NVVMCheckSMVersion(StringRef smVersion, bool exactMatch = false)
40-
: exactMatch(exactMatch) {
41-
parse(smVersion);
42-
}
43-
NVVMCheckSMVersion(int archVersion, bool archAccelerated, bool exactMatch)
44-
: archVersion(archVersion), archAccelerated(archAccelerated),
45-
exactMatch(exactMatch) {}
46-
47-
// Parses the SM version string and sets the archVersion (as an integer)
48-
// and the archAccelerated flag.
49-
void parse(StringRef smVersion) {
50-
archAccelerated = (smVersion.back() == 'a');
51-
smVersion.drop_front(3)
52-
.take_while([](char c) { return llvm::isDigit(c); })
53-
.getAsInteger(10, archVersion);
29+
// List of SM versions.
30+
// Typically only has one version except for cases where multiple
31+
// arch-accelerated versions are supported.
32+
// For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
33+
llvm::SmallVector<int, 1> smVersionList;
34+
35+
template <typename... Ints>
36+
NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
37+
: archAccelerated(archAccelerated), smVersionList({smVersions...}) {
38+
assert((archAccelerated || smVersionList.size() == 1) &&
39+
"non arch-accelerated SM version list must be a single version!");
5440
}
5541

56-
bool isCompatible(const NVVMCheckSMVersion &targetSM) const {
57-
if (exactMatch)
58-
return (*this) == targetSM;
42+
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
43+
assert(targetSM.smVersionList.size() == 1 &&
44+
"target SM version list must be a single version!");
5945

60-
return archAccelerated
61-
? archVersion <= targetSM.archVersion && targetSM.archAccelerated
62-
: archVersion <= targetSM.archVersion;
46+
if (archAccelerated) {
47+
if (!targetSM.archAccelerated)
48+
return false;
49+
50+
for (auto version : smVersionList) {
51+
if (version == targetSM.smVersionList[0])
52+
return true;
53+
}
54+
} else {
55+
return targetSM.smVersionList[0] >= smVersionList[0];
56+
}
57+
58+
return false;
6359
}
6460

65-
bool operator==(const NVVMCheckSMVersion &other) const {
66-
return archVersion == other.archVersion &&
67-
archAccelerated == other.archAccelerated;
61+
bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
62+
63+
// Parses an SM version string and returns an equivalent NVVMCheckSMVersion
64+
// object.
65+
static const NVVMCheckSMVersion
66+
getTargetSMVersionFromStr(StringRef smVersionString) {
67+
bool isAA = smVersionString.back() == 'a';
68+
69+
int smVersionInt;
70+
smVersionString.drop_front(3)
71+
.take_while([](char c) { return llvm::isDigit(c); })
72+
.getAsInteger(10, smVersionInt);
73+
74+
return NVVMCheckSMVersion(isAA, smVersionInt);
6875
}
6976
};
77+
7078
} // namespace NVVM
7179
} // namespace mlir
7280

@@ -76,21 +84,34 @@ namespace mlir {
7684

7785
namespace OpTrait {
7886

79-
template <int MinVersion, bool ArchAccelerated = false, bool ExactMatch = false>
87+
template <int MinVersion>
8088
class NVVMRequiresSM {
8189
public:
8290
template <typename ConcreteOp>
8391
class Impl
84-
: public OpTrait::TraitBase<
85-
ConcreteOp,
86-
NVVMRequiresSM<MinVersion, ArchAccelerated, ExactMatch>::Impl>,
92+
: public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
8793
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
8894
public:
8995
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
90-
return NVVM::NVVMCheckSMVersion(MinVersion, ArchAccelerated, ExactMatch);
96+
return NVVM::NVVMCheckSMVersion(false, MinVersion);
9197
}
9298
};
9399
};
100+
101+
template <int... SMVersions>
102+
class NVVMRequiresSMa {
103+
public:
104+
template <typename ConcreteOp>
105+
class Impl : public OpTrait::TraitBase<ConcreteOp,
106+
NVVMRequiresSMa<SMVersions...>::Impl>,
107+
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
108+
public:
109+
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
110+
return NVVM::NVVMCheckSMVersion(true, SMVersions...);
111+
}
112+
};
113+
};
114+
94115
} // namespace OpTrait
95116
} // namespace mlir
96117
#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_

mlir/include/mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,17 @@ def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
2727
];
2828
}
2929

30-
class NVVMRequiresSM<int minVersion, string isArchAccelerated = "false",
31-
string exactMatch = "false"> :
32-
ParamNativeOpTrait<"NVVMRequiresSM",
33-
!cast<string>(minVersion) # "," # isArchAccelerated # ","
34-
# exactMatch>;
35-
36-
class NVVMRequiresSMa<int version> : NVVMRequiresSM<version, "true", "true">;
30+
class NVVMRequiresSM<int minVersion> :
31+
ParamNativeOpTrait<"NVVMRequiresSM", !cast<string>(minVersion)>;
32+
33+
class StrJoin<string sep, list<string> str_list> {
34+
string ret = !foldl("", str_list, a, b,
35+
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
36+
}
37+
38+
class NVVMRequiresSMa<list<int> smVersions> :
39+
ParamNativeOpTrait<"NVVMRequiresSMa",
40+
StrJoin<",", !foreach(vers, smVersions,
41+
!cast<string>(vers))>.ret>;
3742

3843
#endif //NVVM_REQUIRES_SM_TRAITS

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,16 +1743,18 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
17431743
"NVVM target attribute must be attached to a GPU module");
17441744
}
17451745

1746-
NVVMCheckSMVersion targetSMVersion(getChip());
1747-
if (targetSMVersion.archVersion < 20) {
1746+
const NVVMCheckSMVersion targetSMVersion =
1747+
NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
1748+
if (!targetSMVersion.isMinimumSMVersion()) {
17481749
return emitError(gpuModule->getLoc(),
17491750
"Minimum NVVM target SM version is sm_20");
17501751
}
1751-
1752+
17521753
gpuModuleOp->walk([&](Operation *op) {
17531754
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
1754-
NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
1755-
if (!requirement.isCompatible(targetSMVersion)) {
1755+
const NVVMCheckSMVersion requirement =
1756+
reqOp.getRequiredMinSMVersion();
1757+
if (!requirement.isCompatibleWith(targetSMVersion)) {
17561758
op->emitOpError() << "is not supported on " << getChip();
17571759
return WalkResult::interrupt();
17581760
}

mlir/test/Dialect/LLVMIR/nvvm-check-targetSM.mlir

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@ gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
1313
test.nvvm_requires_sm_80
1414
}
1515

16-
gpu.module @check_valid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90a">] {
16+
gpu.module @check_valid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90a">] {
1717
test.nvvm_requires_sm_90a
1818
}
1919

20-
gpu.module @check_valid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_90a">] {
21-
test.nvvm_requires_sm_atleast_90_aa
20+
gpu.module @check_valid_SM_arch_acc_2 [#nvvm.target<chip = "sm_90a">] {
21+
test.nvvm_requires_sm_80
22+
}
23+
24+
gpu.module @check_valid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_90a">] {
25+
test.nvvm_requires_sm_90a_or_sm_100a
2226
}
2327

24-
gpu.module @check_valid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_100a">] {
25-
test.nvvm_requires_sm_atleast_90_aa
28+
gpu.module @check_valid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_100a">] {
29+
test.nvvm_requires_sm_90a_or_sm_100a
2630
}
2731

2832

@@ -35,7 +39,7 @@ gpu.module @disable_verify_target2 [#nvvm.target<chip = "sm_70", verifyTarget =
3539
}
3640

3741
gpu.module @disable_verify_target3 [#nvvm.target<chip = "sm_90", verifyTarget = false>] {
38-
test.nvvm_requires_sm_atleast_90_aa
42+
test.nvvm_requires_sm_90a_or_sm_100a
3943
}
4044

4145
// -----
@@ -54,28 +58,28 @@ gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
5458

5559
// -----
5660

57-
gpu.module @check_invalid_SM_arch_acc_exact_1 [#nvvm.target<chip = "sm_90">] {
61+
gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
5862
// expected-error @below {{is not supported on sm_90}}
5963
test.nvvm_requires_sm_90a
6064
}
6165

6266
// -----
6367

64-
gpu.module @check_invalid_SM_arch_acc_exact_2 [#nvvm.target<chip = "sm_80">] {
68+
gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
6569
// expected-error @below {{is not supported on sm_80}}
6670
test.nvvm_requires_sm_90a
6771
}
6872

6973
// -----
7074

71-
gpu.module @check_invalid_SM_arch_acc_atleast_1 [#nvvm.target<chip = "sm_80">] {
75+
gpu.module @check_invalid_SM_arch_acc_multi_1 [#nvvm.target<chip = "sm_80">] {
7276
// expected-error @below {{is not supported on sm_80}}
73-
test.nvvm_requires_sm_atleast_90_aa
77+
test.nvvm_requires_sm_90a_or_sm_100a
7478
}
7579

7680
// -----
7781

78-
gpu.module @check_invalid_SM_arch_acc_atleast_2 [#nvvm.target<chip = "sm_90">] {
82+
gpu.module @check_invalid_SM_arch_acc_multi_2 [#nvvm.target<chip = "sm_90">] {
7983
// expected-error @below {{is not supported on sm_90}}
80-
test.nvvm_requires_sm_atleast_90_aa
84+
test.nvvm_requires_sm_90a_or_sm_100a
8185
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,20 +2805,20 @@ def TestLinalgFillOp :
28052805
// Test NVVM RequiresSM trait.
28062806
//===----------------------------------------------------------------------===//
28072807

2808-
def TestNVVMRequiresSMOp : TEST_Op<"nvvm_requires_sm_80",
2809-
[NVVMRequiresSM<80>]> {
2808+
def TestNVVMRequiresSMOp :
2809+
TEST_Op<"nvvm_requires_sm_80", [NVVMRequiresSM<80>]> {
28102810
let arguments = (ins );
28112811
let assemblyFormat = "attr-dict";
28122812
}
28132813

2814-
def TestNVVMRequiresAtleastSMArchCondOp :
2815-
TEST_Op<"nvvm_requires_sm_atleast_90_aa", [NVVMRequiresSM<90, "true">]> {
2814+
def TestNVVMRequiresSMArchCondOp :
2815+
TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSMa<[90]>]> {
28162816
let arguments = (ins );
28172817
let assemblyFormat = "attr-dict";
28182818
}
28192819

2820-
def TestNVVMRequiresExactSMArchCondOp : TEST_Op<"nvvm_requires_sm_90a",
2821-
[NVVMRequiresSMa<90>]> {
2820+
def TestNVVMRequirestSMArchCondMultiOp :
2821+
TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
28222822
let arguments = (ins );
28232823
let assemblyFormat = "attr-dict";
28242824
}

0 commit comments

Comments
 (0)