Commit e40c213

yiqian1 and antiagainst authored
[AMD] Improve scaled dot with (b)f16 types on GFX950 (#7693)
For such cases we need to upcast the mxfp operand into (b)f16 to utilize the (b)f16 MFMA intrinsics. On GFX950 that upcasting now has native instruction support.

Co-authored-by: Lei Zhang <[email protected]>
1 parent 2e05786 commit e40c213
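For reference, the scale handling that the new lowering tests check for (zext of the i8 scale, shift left by 23, bitcast to f32) amounts to placing the E8M0 scale byte directly into the f32 exponent field. Below is a minimal Python sketch of that decoding under that assumption; the helper name is illustrative and not part of the change:

    import struct

    def e8m0_scale_to_f32(scale_byte):
        # Shift the 8-bit scale into the f32 exponent field (bit 23 onward) and
        # reinterpret the bits as a float; this mirrors the zext + shl-by-23 +
        # bitcast sequence checked in the upcast_mxfp.mlir lit test below.
        bits = (scale_byte & 0xFF) << 23
        return struct.unpack("<f", struct.pack("<I", bits))[0]

    # A scale byte of 127 encodes 2**0, i.e. a multiplier of 1.0.
    assert e8m0_scale_to_f32(127) == 1.0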

File tree: 6 files changed (+278 −56 lines)

python/test/unit/language/test_core.py

Lines changed: 3 additions & 0 deletions
@@ -4442,6 +4442,9 @@ def make_finite(x, dtype):
         assert 'st.global.v4' in ptx
         assert (re.search(r'(mma|wgmma.mma_async).sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.(f|bf)16.(f|bf)16', ptx)
                 or "tcgen05.mma.cta_group::1.kind::f16" in ptx)
+    if is_hip_cdna4() and normal_type in ["bf16", "fp16"]:
+        amdgcn = pgm.asm['amdgcn']
+        assert (re.search(r"v_cvt_scalef32_pk_.*?(fp4|fp8|bf8).*?op_sel", amdgcn))
 
 
 @pytest.mark.interpreter

test/Conversion/amd/upcast_mxfp.mlir

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefixes=GFX950 %s
+
+// -----
+
+// GFX950-LABEL: upcast_mxfp4
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 4096 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @upcast_mxfp4(%arg0 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg1 : tensor<32x2xi8, #blocked>) {
+    // GFX950-DAG: %[[CST:.*]] = llvm.mlir.constant(23 : i32) : i32
+    // GFX950-DAG: %[[ISCALE:.*]] = llvm.zext %{{.*}} : i8 to i32
+    // GFX950: %[[INTS:.*]] = llvm.shl %[[ISCALE]], %[[CST]] : i32
+    // GFX950: %[[SCALE:.*]] = llvm.bitcast %[[INTS]] : i32 to f32
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp4 %[[REG:.*]][0], %[[SCALE]] : vector<2xbf16>
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp4 %[[REG]][2], %[[SCALE]] : vector<2xbf16>
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp4 %[[REG]][1], %[[SCALE]] : vector<2xbf16>
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp4 %[[REG]][3], %[[SCALE]] : vector<2xbf16>
+    %1 = amdgpu.upcast_mxfp %arg0, %arg1 fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #blocked> -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+}
+
+
+// -----
+
+// GFX950-LABEL: upcast_mxfp8
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 4096 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @upcast_mxfp8(%arg0 : tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %arg1 : tensor<32x2xi8, #blocked>) {
+    // GFX950-DAG: %[[CST:.*]] = llvm.mlir.constant(23 : i32) : i32
+    // GFX950-DAG: %[[ISCALE:.*]] = llvm.zext %{{.*}} : i8 to i32
+    // GFX950: %[[INTS:.*]] = llvm.shl %[[ISCALE]], %[[CST]] : i32
+    // GFX950: %[[SCALE:.*]] = llvm.bitcast %[[INTS]] : i32 to f32
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[REG:.*]][false], %[[SCALE]] : vector<2xbf16>
+    // GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[REG]][true], %[[SCALE]] : vector<2xbf16>
+    %1 = amdgpu.upcast_mxfp %arg0, %arg1 fp_type = e4m3 {fastMath = false} : tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<32x2xi8, #blocked> -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+}
+
+// -----
+
+// GFX950-LABEL: upcast_mxbf8
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 4096 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @upcast_mxbf8(%arg0 : tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %arg1 : tensor<32x2xi8, #blocked>) {
+    // GFX950-DAG: %[[CST:.*]] = llvm.mlir.constant(23 : i32) : i32
+    // GFX950-DAG: %[[ISCALE:.*]] = llvm.zext %{{.*}} : i8 to i32
+    // GFX950: %[[INTS:.*]] = llvm.shl %[[ISCALE]], %[[CST]] : i32
+    // GFX950: %[[SCALE:.*]] = llvm.bitcast %[[INTS]] : i32 to f32
+    // GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[REG:.*]][false], %[[SCALE]] : vector<2xf16>
+    // GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[REG]][true], %[[SCALE]] : vector<2xf16>
+    %1 = amdgpu.upcast_mxfp %arg0, %arg1 fp_type = e5m2 {fastMath = false} : tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<32x2xi8, #blocked> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+}
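For context on the mxfp4 case: each i8 element packs two e2m1 values, which is why the 32x32 i8 operand upcasts to a 64x32 bf16 tensor and the kWidth doubles from 4 to 8. Below is a rough reference sketch of what one packed byte expands to, assuming the standard OCP e2m1 value table and low-nibble-first packing; both the helper name and the packing order are assumptions, not part of the change:

    # Positive e2m1 magnitudes, indexed by the low three bits (exponent, mantissa).
    _E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

    def decode_mxfp4_byte(byte, scale):
        # Expand one packed fp4 byte into two scaled values, roughly what the
        # rocdl.cvt.scalef32.pk.bf16.fp4 conversions above compute in hardware.
        def nibble_to_float(nib):
            sign = -1.0 if nib & 0x8 else 1.0
            return sign * _E2M1_MAGNITUDES[nib & 0x7]
        low, high = byte & 0xF, (byte >> 4) & 0xF
        return nibble_to_float(low) * scale, nibble_to_float(high) * scale

    # Example: 0xE2 packs +1.0 (low nibble) and -4.0 (high nibble).
    assert decode_mxfp4_byte(0xE2, 1.0) == (1.0, -4.0)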

test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir

Lines changed: 28 additions & 0 deletions
@@ -224,6 +224,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   }
 }
 
+// -----
+
+// CHECK-LABEL: mfma_dot_scaled_bf16_fp8e4
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_bf16_fp8e4(
+      %arg0: tensor<32x64xbf16, #blocked2>,
+      %arg1: tensor<64x32xf8E4M3FN, #blocked>,
+      %arg2: tensor<32x2xi8, #blocked1>,
+      %arg3: tensor<32x32x!tt.ptr<f32>, #blocked>
+  ) {
+    // CHECK-NOT: tt.fp_to_fp
+    // CHECK-NOT: tt.dot_scaled
+    // CHECK: %[[A:.*]] = ttg.convert_layout %{{.*}} : tensor<32x64xbf16, #blocked{{.*}}> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK: %[[B:.+]] = ttg.convert_layout %{{.*}} : tensor<64x32xf8E4M3FN, #blocked{{.*}}> -> tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    // CHECK: %[[S:.+]] = ttg.convert_layout %{{.*}} : tensor<32x2xi8, #blocked{{.*}}> -> tensor<32x2xi8, #blocked{{.*}}>
+    // CHECK: %[[UB:.+]] = amdgpu.upcast_mxfp %[[B]], %[[S]] fp_type = e4m3 {fastMath = false} : tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<32x2xi8, #blocked{{.*}}> -> tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    // CHECK: %{{.*}} = tt.dot %[[A]], %[[UB]], %{{.*}} : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma>
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    %1 = tt.dot_scaled %arg0, %arg1 scale %arg2, %cst lhs = bf16 rhs = e4m3 {fastMath = false} : tensor<32x64xbf16, #blocked2> * tensor<64x32xf8E4M3FN, #blocked>, tensor<32x2xi8, #blocked1> -> tensor<32x32xf32, #blocked>
+    tt.store %arg3, %1 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 14 additions & 9 deletions
@@ -253,9 +253,11 @@ LogicalResult UpcastMXFPOp::verify() {
   Builder b(getContext());
   if (xTy.getElementType() != b.getBF16Type() &&
       xTy.getElementType() != b.getF16Type() &&
-      xTy.getElementType() != b.getI8Type()) {
-    return emitOpError(
-        "element type of the first operand must be bf16/fp16 or i8");
+      xTy.getElementType() != b.getI8Type() &&
+      xTy.getElementType() != b.getType<Float8E4M3FNType>() &&
+      xTy.getElementType() != b.getType<Float8E5M2Type>()) {
+    return emitOpError("element type of the first operand must be bf16/fp16, "
+                       "OCP fp8/bf8 or i8");
   }
 
   if (scaleTy.getElementType() != b.getI8Type()) {
@@ -328,27 +330,30 @@ UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
                                Type outputElemType) {
   MLIRContext *ctx = inputTensor.getContext();
   auto xTy = inputTensor.getType();
-  if (inputElemType != ScaleDotElemType::E2M1)
+  if (!(inputElemType == ScaleDotElemType::E2M1 ||
+        inputElemType == ScaleDotElemType::E4M3 ||
+        inputElemType == ScaleDotElemType::E5M2))
     return xTy;
 
+  auto factor = inputElemType == ScaleDotElemType::E2M1 ? 2 : 1;
   auto xShape = xTy.getShape();
   auto newShape = llvm::to_vector(xShape);
   auto encoding = xTy.getEncoding();
   if (!encoding) {
-    newShape.back() *= 2;
+    newShape.back() *= factor;
     return RankedTensorType::get(xShape, outputElemType);
   }
 
   auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-  auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
-                                                  oldEncoding.getParent(),
-                                                  oldEncoding.getKWidth() * 2);
+  auto newVEncoding = DotOperandEncodingAttr::get(
+      ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+      oldEncoding.getKWidth() * factor);
   // Figure out the K dimension for the input A/B, given that the return
   // type is upcasted A/B type so we need to update the proper dim size.
   const int opIdx = oldEncoding.getOpIdx();
   const bool hasBatch = xShape.size() == 3;
   const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-  newShape[kIdx] *= 2;
+  newShape[kIdx] *= factor;
   return RankedTensorType::get(newShape, outputElemType, newVEncoding);
 }
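The deduceOutputType change above generalizes the shape math from fp4-only to the OCP fp8/bf8 types: e2m1 packs two values per byte, so the K dimension and kWidth double, while e4m3/e5m2 hold one value per byte and keep their sizes. Below is a small Python sketch of the same shape rule, mirroring the diff; the function is illustrative, not the C++ API:

    def deduce_upcast_shape(shape, elem_type, op_idx, has_batch=False):
        # e2m1 packs two values per byte; fp8/bf8 are one value per byte.
        factor = 2 if elem_type == "e2m1" else 1
        # K is dim 1 for the A operand and dim 0 for B, shifted by a batch dim.
        k_idx = (1 if op_idx == 0 else 0) + (1 if has_batch else 0)
        new_shape = list(shape)
        new_shape[k_idx] *= factor
        return new_shape

    # Matches the upcast_mxfp4 test: a 32x32 i8 B operand becomes 64x32 bf16.
    assert deduce_upcast_shape([32, 32], "e2m1", op_idx=1) == [64, 32]
    # fp8/bf8 operands keep their shape, e.g. the 64x32 f8E4M3FN case.
    assert deduce_upcast_shape([64, 32], "e4m3", op_idx=1) == [64, 32]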