Commit f6626cd

[AMD] Support scaled dot for gfx12 (#7644)
Support emulation of scaled dot by decomposing it into a normal dot with upcast operands.

Signed-off-by: Ilya Veselov <[email protected]>
1 parent d016691 commit f6626cd
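For reference, the emulation this commit adds boils down to: upcast the mx operand to a regular float type, apply the per-block e8m0 scales, then issue an ordinary dot. A minimal NumPy sketch of those semantics, assuming only the left operand carries scales and one shared exponent per 32 elements along K (function and parameter names are mine, not part of the patch):

import numpy as np

def scaled_dot_ref(a_f16, a_scale_e8m0, b_f16):
    # a_f16: (M, K) already-upcast payload, a_scale_e8m0: (M, K // 32) uint8,
    # b_f16: (K, N); each e8m0 byte encodes the power-of-two scale 2**(e - 127).
    scale = np.float32(2.0) ** (a_scale_e8m0.astype(np.int32) - 127)
    a = a_f16.astype(np.float32) * np.repeat(scale, 32, axis=1)  # apply block scales
    return a @ b_f16.astype(np.float32)                          # then a normal dot

The changes below implement this idea at the IR level: packed fp4 operands are upcast to fp16/bf16 (see the new Fp4ToFpOpToLLVM.cpp), and the scaled dot is rewritten as a normal dot.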

File tree: 12 files changed, +251 -155 lines

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 2 deletions
@@ -2137,9 +2137,11 @@ LogicalResult DotOperandEncodingAttr::verify(
 
   if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     if (kWidth != 16 && parentAttr.getVersion() == 1 ||
-        kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
+        kWidth != 4 && kWidth != 8 && kWidth != 16 &&
+            parentAttr.getVersion() == 2)
       return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
-                            "gfx11 and 8/16 for gfx12";
+                            "gfx11 and 4/8/16 for gfx12 (including packed "
+                            "cases for `scaled_dot`)";
     return success();
   }

lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
 
   LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
                                 PatternRewriter &rewriter) const override {
+    if (isa_and_nonnull<MmaEncodingTrait>(
+            scaledDotOp.getResult().getType().getEncoding()))
+      return failure();
+
     // TODO: add support for m/n packed formats.
     if (!scaledDotOp.getLhsKPack() || !scaledDotOp.getRhsKPack())
       return failure();

python/test/unit/language/test_core.py

Lines changed: 7 additions & 5 deletions
@@ -80,6 +80,8 @@ def promotion_numpy_2_0():
     # 0 is a special value for automatic heuristic
     if is_hip_cdna():
         mma_nonk_sizes = [0, 16, 32]
+    elif is_hip_gfx12():
+        mma_nonk_sizes = [16]
 else:
     THREADS_PER_WARP = 32
 

@@ -4196,12 +4198,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, nu
     if cc < (8, 9):
         pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if not is_hip_cdna():
-            pytest.skip("scaled_dot only implemented for HIP CDNA")
+        if not (is_hip_cdna() or is_hip_gfx12()):
+            pytest.skip("scaled_dot only implemented for HIP CDNA and gfx12")
         if "e4m3" in (mxfp_type, normal_type):
-            if not (is_hip_cdna3() or is_hip_cdna4()):
-                pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for CDNA3 and CDNA4")
-        if mma == 16 and K == 64:
+            if not (is_hip_cdna3() or is_hip_cdna4() or is_hip_gfx12()):
+                pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for CDNA3, CDNA4, gfx12")
+        if mma == 16 and K == 64 and not is_hip_gfx12():
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
 @triton.jit

test/TritonGPU/invalid-attributes.mlir

Lines changed: 3 additions & 8 deletions
@@ -42,26 +42,21 @@
 
 // -----
 
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 4/8/16 for gfx12 (including packed cases for `scaled_dot`)}}
 #wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma}>
 
 // -----
 
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 4/8/16 for gfx12 (including packed cases for `scaled_dot`)}}
 #wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}>
 
 // -----
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 4/8/16 for gfx12 (including packed cases for `scaled_dot`)}}
 #wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 32}>
 
-// -----
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
-#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}>
-#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}>
-
 // -----
 
 // expected-error@+1 {{version must be in the [0, 4] range}}

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ add_triton_library(TritonAMDGPUToLLVM
   SPMDOpToLLVM.cpp
   SchedInstructions.cpp
   UpcastMXFPToLLVM.cpp
+  Fp4ToFpOpToLLVM.cpp
   MembarUtility.cpp
   ScalarizePackedFOps.cpp

third_party/amd/lib/TritonAMDGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+#include "PatternTritonGPUOpToLLVM.h"
+
+#include "Utility.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <array>
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::gpu;
+using ::mlir::LLVM::AMD::upcast8xMxfp4_SW;
+
+namespace {
+
+class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
+public:
+  Fp4ToFpOpPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<Fp4ToFpOp>(typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(Fp4ToFpOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto elemType = op.getType().getElementType();
+    assert(elemType == f16_ty || elemType == bf16_ty);
+    bool toFp16 = elemType == f16_ty;
+
+    auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+
+    SmallVector<Value> results;
+    results.reserve(xVals.size() * 2);
+    assert(xVals.size() % 4 == 0);
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    for (int i = 0; i < xVals.size(); i += 4) {
+      Value packedVec = b.undef(vec_ty(i8_ty, 4));
+      for (int j : llvm::seq(4)) {
+        Value v = xVals[i + j];
+        packedVec = b.insert_element(packedVec, v, b.i32_val(j));
+      }
+      SmallVector<Value, 4> v4i32 =
+          upcast8xMxfp4_SW(rewriter, op, toFp16, packedVec);
+      for (int j = 0; j < 4; j++) {
+        Value elements = b.bitcast(v4i32[j], vec_ty(elemType, 2));
+        results.push_back(b.extract_element(elements, b.i32_val(0)));
+        results.push_back(b.extract_element(elements, b.i32_val(1)));
+      }
+    }
+
+    Value result = packLLElements(loc, getTypeConverter(), results, rewriter,
+                                  op.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // anonymous namespace
+
+void mlir::triton::AMD::populateFp4ToFpToLLVMPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<Fp4ToFpOpPattern>(typeConverter, benefit);
+}

third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,10 @@ void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                       const TargetInfo &targetInfo,
                                       PatternBenefit benefit);
 
+void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                   RewritePatternSet &patterns,
+                                   PatternBenefit benefit);
+
 } // namespace mlir::triton::AMD
 
 #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,8 @@ struct ConvertTritonAMDGPUToLLVM
                                                       patterns, AMDBenefit);
     mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
                                                         targetInfo, AMDBenefit);
+    mlir::triton::AMD::populateFp4ToFpToLLVMPatterns(typeConverter, patterns,
+                                                     AMDBenefit);
 
     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 1 addition & 140 deletions
@@ -19,149 +19,10 @@
 using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
+using ::mlir::LLVM::AMD::upcast8xMxfp4_SW;
 
 namespace {
 
-SmallVector<Value, 4> upcast8xMxfp4_SW(RewriterBase &rewriter,
-                                       amdgpu::UpcastMXFPOp upcastOp,
-                                       bool tofp16, Value packedVec) {
-  Location loc = upcastOp.getLoc();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  // MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively.
-  // For a specific S, we have a total of 8 bit patterns. We can encode all
-  // these 8 resultant bf16/fp16 bit patterns in a lookup table (LUT). It
-  // happens that llvm.amdgcn.perm supports selecting 4 bytes from 8 input bytes
-  // using a 4-byte selector. So the overall idea is to use llvm.amdgcn.perm to
-  // implement such a LUT; though we need to select the two bytes for the
-  // resultant bf16/fp16 bit patterns separately. For the byte containing S, we
-  // also need to handle the S and E bits separately.
-
-  // FP4 has 4 bits: S.EE.M. Bf16/fp16 bit patterns for positive values:
-  //
-  // FP4    | BF16   | FP16   | Value
-  // ------ | ------ | ------ | -----
-  // 0.00.0 | 0x0000 | 0x0000 | + 0.0
-  // 0.00.1 | 0x3f00 | 0x3800 | + 0.5
-  // 0.01.0 | 0x3f80 | 0x3c00 | + 1.0
-  // 0.01.1 | 0x3fc0 | 0x3e00 | + 1.5
-  // 0.10.0 | 0x4000 | 0x4000 | + 2.0
-  // 0.10.1 | 0x4040 | 0x4200 | + 3.0
-  // 0.11.0 | 0x4080 | 0x4400 | + 4.0
-  // 0.11.1 | 0x40c0 | 0x4600 | + 6.0
-  //
-  // Encode Byte #0 (M) for BF16/FP16 in a LUT.
-  Value resB0LutLo = tofp16 ? b.i32_val(0) : b.i32_val(0xc0800000);
-  Value resB0LutHi = tofp16 ? b.i32_val(0) : b.i32_val(0xc0804000);
-  // Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT.
-  Value resB1LutLoNoS = tofp16 ? b.i32_val(0x3e3c3800) : b.i32_val(0x3f3f3f00);
-  Value resB1LutHiNoS = tofp16 ? b.i32_val(0x46444240) : b.i32_val(0x40404040);
-
-  Type i32Ty = rewriter.getI32Type();
-  auto permU32FnTy = LLVM::LLVMFunctionType::get(i32Ty, {i32Ty, i32Ty, i32Ty});
-  LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
-      rewriter, upcastOp, "llvm.amdgcn.perm", permU32FnTy);
-
-  // Start with 8 mxfp4 elements in a single i32 register
-  // | e7e6 | e5e4 | e3e2 | e1e0 |
-  Value input = b.bitcast(packedVec, i32Ty);
-
-  // Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively.
-  // e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] |
-  Value e2m1_6420_idx = b.and_(input, b.i32_val(0x07070707));
-  // e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 |
-  Value e2m1_7531_idx = b.and_(input, b.i32_val(0x70707070));
-  // e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] |
-  e2m1_7531_idx = b.lshr(e2m1_7531_idx, b.i32_val(4));
-
-  // Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7
-  // s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] |
-  Value s_6420 = b.and_(input, b.i32_val(0x08080808));
-  // s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 |
-  s_6420 = b.shl(s_6420, b.i32_val(4));
-  // s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 |
-  Value s_7531 = b.and_(input, b.i32_val(0x80808080));
-
-  // Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements
-  // Select Byte #0. It's always 0 if upcasting to fp16.
-  // resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 |
-  Value resB0_6420 = b.i32_val(0);
-  if (!tofp16) {
-    resB0_6420 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {resB0LutHi, resB0LutLo, e2m1_6420_idx})
-                     .getResult();
-  }
-  // Select Byte #1
-  Value resB1NoS_6420 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx})
-          .getResult();
-  // resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 |
-  Value resB1_6420 = b.or_(resB1NoS_6420, s_6420);
-  // Construct 16-bit values of e0 and e2
-  // res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 |
-  Value res_20 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1_6420, resB0_6420, b.i32_val(0x05010400)})
-          .getResult();
-  // Construct 16-bit values of e4 and e6
-  // res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 |
-  Value res_64 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1_6420, resB0_6420, b.i32_val(0x07030602)})
-          .getResult();
-
-  // Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements
-  // This is a copy of step 3 on different group of elements
-  // Select Byte #0. It's always 0 if upcasting to fp16.
-  // resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 |
-  Value resB0_7531 = b.i32_val(0);
-  if (!tofp16) {
-    resB0_7531 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {resB0LutHi, resB0LutLo, e2m1_7531_idx})
-                     .getResult();
-  }
-  // Select Byte #1
-  Value resB1NoS_7531 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx})
-          .getResult();
-  // resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 |
-  Value resB1_7531 = b.or_(resB1NoS_7531, s_7531);
-  // Construct 16-bit values of e1 and e3
-  // res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 |
-  Value res_31 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1_7531, resB0_7531, b.i32_val(0x05010400)})
-          .getResult();
-  // Construct 16-bit values of e5 and e7
-  // res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 |
-  Value res_75 =
-      LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                             {resB1_7531, resB0_7531, b.i32_val(0x07030602)})
-          .getResult();
-
-  // Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7
-  // res_10 = | e1_f16 | e0_f16 |
-  Value res_10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {res_31, res_20, b.i32_val(0x05040100)})
-                     .getResult();
-  // res_32 = | e3_f16 | e2_f16 |
-  Value res_32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {res_31, res_20, b.i32_val(0x07060302)})
-                     .getResult();
-  // res_54 = | e5_f16 | e4_f16 |
-  Value res_54 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {res_75, res_64, b.i32_val(0x05040100)})
-                     .getResult();
-  // res_76 = | e7_f16 | e6_f16 |
-  Value res_76 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
-                                        {res_75, res_64, b.i32_val(0x07060302)})
-                     .getResult();
-
-  return {res_10, res_32, res_54, res_76};
-}
-
 SmallVector<Value, 8> upcastMxfp4_SW(RewriterBase &rewriter,
                                      amdgpu::UpcastMXFPOp upcastOp, bool toFp16,
                                      ArrayRef<Value> values, int idx) {