Commit e5a00f5

[AMD] NFC: Drop version minor for AMD MFMA layout (#7285)

AMD's MFMA layout does not need minor version information the way NVIDIA's MMA layout does: in the current codebase the minor version always defaults to 0. This PR drops it and switches the MFMA layout to a single `version` parameter.

1 parent 09d5113 commit e5a00f5

33 files changed: +151 −169 lines
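Concretely, the attribute change looks like this in TritonGPU IR (lines taken verbatim from test/Analysis/amd/test-alignment.mlir below):

// Before: versionMinor carried no information and was always 0.
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
// After: a single version parameter.
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>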

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (6 additions, 7 deletions)
@@ -1002,11 +1002,11 @@ An encoding for tensors that have been produced by MFMA matrix core instructions
 available on AMD Instinct GPUs of CDNA architectures.

 It is characterized by the following parameters:
-- `versionMajor` and `versionMinor` indicates the GPU architecture:
-  - 1.0: gfx908, i.e. CDNA1
-  - 2.0: gfx90a: i.e. CDNA2
-  - 3.0: gfx942: CDNA3
-  - 4.0: gfx950: CDNA4
+- `version` indicates the GPU architecture:
+  - 1: gfx908: CDNA1
+  - 2: gfx90a: CDNA2
+  - 3: gfx942: CDNA3
+  - 4: gfx950: CDNA4
 - `warpsPerCTA` indicates the warp layout in the block.
 - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
 - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
@@ -1096,8 +1096,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,

   let parameters = (
     ins
-    "unsigned": $versionMajor,
-    "unsigned": $versionMinor,
+    "unsigned": $version,
     ArrayRefParameter<"unsigned">:$warpsPerCTA,
     "unsigned":$MDim,
     "unsigned":$NDim,

lib/Dialect/TritonGPU/IR/Dialect.cpp (13 additions, 24 deletions)
@@ -1315,8 +1315,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};

-  unsigned versionMajor = 0;
-  unsigned versionMinor = 0;
+  unsigned version = 0;
   SmallVector<unsigned> warpsPerCTA;
   SmallVector<unsigned> instrShape;
   bool isTransposed;
@@ -1325,12 +1324,8 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAOrder;

   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "versionMajor") {
-      if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
-        return {};
-    }
-    if (attr.getName() == "versionMinor") {
-      if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
+    if (attr.getName() == "version") {
+      if (parseUInt(parser, attr, version, "version").failed())
         return {};
     }
     if (attr.getName() == "warpsPerCTA") {
@@ -1369,14 +1364,13 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};

   return parser.getChecked<AMDMfmaEncodingAttr>(
-      parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
-      instrShape[0], instrShape[1], isTransposed, *CTALayout);
+      parser.getContext(), version, warpsPerCTA, instrShape[0], instrShape[1],
+      isTransposed, *CTALayout);
 }

 void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
-          << "versionMajor = " << getVersionMajor() //
-          << ", versionMinor = " << getVersionMinor() //
+          << "version = " << getVersion() //
           << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" //
           << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
           << ", isTransposed = " << getIsTransposed();
@@ -1385,17 +1379,12 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "}>";
 }

-LogicalResult
-AMDMfmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
-                            unsigned versionMajor, unsigned versionMinor,
-                            llvm::ArrayRef<unsigned int> warpsPerCTA,
-                            unsigned mDim, unsigned nDim, bool isTransposed,
-                            mlir::triton::gpu::CTALayoutAttr) {
-  if (!(versionMajor >= 0 && versionMajor <= 4)) {
-    return emitError() << "major version must be in the [0, 4] range";
-  }
-  if (versionMinor != 0) {
-    return emitError() << "minor version must be 0";
+LogicalResult AMDMfmaEncodingAttr::verify(
+    function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
+    llvm::ArrayRef<unsigned int> warpsPerCTA, unsigned mDim, unsigned nDim,
+    bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
+  if (!(version >= 0 && version <= 4)) {
+    return emitError() << "version must be in the [0, 4] range";
   }
   if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) {
     return emitError()
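As an illustration (a hypothetical encoding, not part of this commit), the new verifier would now reject an out-of-range version:

// Hypothetical: version = 5 exceeds the verifier's [0, 4] bound, so parsing
// this attribute would fail with "version must be in the [0, 4] range".
#bad = #ttg.amd_mfma<{version = 5, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>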
@@ -1949,7 +1938,7 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
   bool isKContig = sharedOrder[0] == kDimIndex;
   // GFX950 supports LDS transpose load instructions, so we need swizzling even
   // when K dimension is not the contiguous dimension.
-  bool isGFX950 = getVersionMajor() == 4;
+  bool isGFX950 = getVersion() == 4;
   bool swizzleNonKContig =
       isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16);
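The tests in this commit reflect that mapping: wherever `version = 4` appears, the module targets gfx950, e.g. this preamble excerpt from test/Conversion/amd/async-ops-alias-scopes.mlir:

#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // ...
}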

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp (1 addition, 1 deletion)
@@ -1525,7 +1525,7 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {

   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
-        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
+        mfmaLayout.getVersion() == 4 && mfmaLayout.getIsTransposed() &&
         (isMfma32 || validForMfma16)))
     return {};

test/Analysis/amd/test-alignment.mlir (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 // RUN: triton-opt %s -test-print-amd-alignment -split-input-file -verify-diagnostics=only-expected -o /dev/null

-#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>

 tt.func public @kernel(%arg0: tensor<256x64xf16, #mma> {tt.contiguity=256 : i32, tt.divisibility=6: i32, tt.constancy=1: i32}) attributes {noinline = false} {
   // expected-remark @below {{contiguity = [128, 32], divisibility = [6, 6], constancy = [1, 1], constant_value = <none>}}

test/Conversion/amd/async-ops-alias-scopes.mlir (3 additions, 3 deletions)
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_token_from_async_wait
   tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_without_token_from_async_wait
   tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_loop_carried_token
   tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {

test/Conversion/amd/dedup-by-constancy.mlir (1 addition, 1 deletion)
@@ -13,7 +13,7 @@
 // only allows duplication within each group of 4 elements. Therefore, we expect 4 icmp, one
 // for each group of 4 elements.
 // In the future, we can reduce the icmp to 2 in such case.
-#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @dedup_by_constancy_mfma(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>

test/Conversion/amd/ds_transpose.mlir (2 additions, 2 deletions)
@@ -1,7 +1,7 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s

-#mma16 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
-#mma32 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#mma16 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+#mma32 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory

test/Conversion/amd/load_store.mlir (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {

 // -----

-#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: global_store_mfma_vec16
   tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {

test/Conversion/amd/mfma-shortcut.mlir (7 additions, 7 deletions)
@@ -1,7 +1,7 @@
 // RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s --check-prefix=GFX942
 // RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx950" -split-input-file | FileCheck %s --check-prefix=GFX950

-#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // GFX942-LABEL: shortcut_mfma16
@@ -16,7 +16,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // GFX942-LABEL: no_shortcut_mfma16
@@ -31,7 +31,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
@@ -95,7 +95,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
@@ -112,7 +112,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
@@ -206,7 +206,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 // -----

-#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
@@ -225,7 +225,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
 #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}>
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // GFX950-LABEL: mfma_linear_permlane_swap
   tt.func public @mfma_linear_permlane_swap(%arg0: tensor<128x128xf16, #mma>) attributes {noinline = false} {
