
Commit 299b3bb

[BACKEND] Don't promote fp8 MMAv2 dot inputs for sm120 (#7409)
Fixes #7188

This speeds up fp8 matmuls on consumer Blackwell (RTX 50xx series) by ~1.9x on large matmuls. sm >= 89 supports MMAv2 with fp8 operands, but prior to this PR, Triton was only using this on sm == 89; on other architectures, fp8 inputs would be promoted to fp16 and the mma would be executed in fp16. This PR causes the fp8->fp16 promotion step to be skipped on any architecture >= 89. It also adds mma variants for fp8 operands with f16 results, which were previously handled (after promotion) via the `FP16_FP16_FP16_FP16` variant.

Evidence that we should be able to use fp8 operands to mmav2 on any architecture >= 89: in the PTX docs (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma), under the "Target ISA Notes" section, e4m3 and e5m2 are listed as supported on sm_89 or higher, and they don't require the "a" suffix, which would indicate that the support is non-backward-compatible.

Perf improvement verified on a 5070 Ti using 03-matrix-multiplication.py (below are flops measurements on large MNK sizes):

Before:

```
matmul-performance-fp8:
         M       N       K      Triton
...
26  3584.0  3584.0  3584.0  101.256071
27  3712.0  3712.0  3712.0   99.947313
28  3840.0  3840.0  3840.0  101.182062
29  3968.0  3968.0  3968.0  101.771419
30  4096.0  4096.0  4096.0  101.206889
```

After:

```
matmul-performance-fp8:
         M       N       K      Triton
...
26  3584.0  3584.0  3584.0  191.309345
27  3712.0  3712.0  3712.0  190.280662
28  3840.0  3840.0  3840.0  195.316740
29  3968.0  3968.0  3968.0  194.305628
30  4096.0  4096.0  4096.0  193.258070
```
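For orientation, here is a small, self-contained C++ sketch of the dispatch rule described above and implemented by the `mmav2SupportsFp8Operands` helper in the diff below. The function names, the `usesMmaV3` flag, and the `main` driver are illustrative only, not Triton code.

```cpp
#include <cstdio>

// Illustrative sketch only: mirrors the rule described in this commit.
// PTX accepts e4m3/e5m2 operands to mma.sync (MMAv2) on sm_89 and newer,
// but on sm_90/sm_100 the instruction is emulated in SASS via fp16
// upcasts, so promotion is only skipped where fp8 MMAv2 runs in hardware.
static bool mmav2HasNativeFp8(int computeCapability) {
  // sm_89 (Ada) and sm_120 (consumer Blackwell) run fp8 MMAv2 natively.
  return computeCapability == 89 || computeCapability == 120;
}

static bool shouldPromoteFp8ToFp16(int computeCapability, bool usesMmaV3) {
  // Hopper's MMAv3 (wgmma) takes fp8 operands directly, so no promotion
  // there either; everywhere else, fp8 dot inputs are upcast to fp16 first.
  return !(mmav2HasNativeFp8(computeCapability) || usesMmaV3);
}

int main() {
  for (int cc : {80, 89, 90, 100, 120})
    std::printf("sm_%d: promote fp8 -> fp16? %s\n", cc,
                shouldPromoteFp8ToFp16(cc, /*usesMmaV3=*/cc == 90) ? "yes"
                                                                   : "no");
}
```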
1 parent 5949ee8 commit 299b3bb

File tree

4 files changed: +97 -3 lines


lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 3 deletions

```diff
@@ -775,6 +775,14 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
   return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
 }
 
+static bool mmav2SupportsFp8Operands(int computeCapability) {
+  // promote operands for sm < 89 since fp8 mma is not natively supported
+  // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and
+  // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has
+  // hardware support for fp8 operands w/ mmav2.
+  return computeCapability == 89 || computeCapability == 120;
+}
+
 // promote operands of dot op if the existing combination is not natively
 // supported.
 static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
@@ -787,10 +795,10 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
       bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // promote operands for sm < 89 since fp8 mma is not natively supported
-      // promote operands for sm >= 90 when mma is not v3
+      // promote to f16 unless there's hardware support for fp8 operands
       if (!isNativeFP8 ||
-          (isNativeFP8 && (computeCapability == 89 || mmaLayout.isHopper())))
+          (isNativeFP8 && (mmav2SupportsFp8Operands(computeCapability) ||
+                           mmaLayout.isHopper())))
         return;
       promoteType = builder.getF16Type();
     } else {
```
test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 33 additions & 0 deletions

```diff
@@ -2257,6 +2257,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
+module attributes {"ttg.num-warps" = 8 : i32, ttg.target = "cuda:120"} {
+  // CHECK-LABEL: mmav2_e5m2_e5m2_fp16
+  tt.func public @mmav2_e5m2_e5m2_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e5m2_e4m3_fp16
+  tt.func public @mmav2_e5m2_e4m3_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e5m2_fp16
+  tt.func public @mmav2_e4m3_e5m2_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e4m3_fp16
+  tt.func public @mmav2_e4m3_e4m3_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}>
 #linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}>
```

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 21 additions & 0 deletions

```diff
@@ -562,6 +562,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   }
 }
 
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: sm120_fp8_dot
+  tt.func public @sm120_fp8_dot(%arg0: tensor<128x256xf32, #blocked>, %arg1: tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>, %arg2: tensor<128x256x!tt.ptr<f8E4M3FN>, #blocked2>, %arg3: tensor<128x128xi1, #blocked1>, %arg4: tensor<128x256xi1, #blocked2>) -> tensor<128x256xf32, #blocked> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf8E4M3FN, #blocked2>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf8E4M3FN, #blocked1>
+    %0 = tt.load %arg1, %arg3, %cst_0 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
+    %1 = tt.load %arg2, %arg4, %cst : tensor<128x256x!tt.ptr<f8E4M3FN>, #blocked2>
+    %2 = ttg.convert_layout %0 : tensor<128x128xf8E4M3FN, #blocked1> -> tensor<128x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %3 = ttg.convert_layout %1 : tensor<128x256xf8E4M3FN, #blocked2> -> tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: {{.*}} = tt.dot {{.*}} tensor<128x128xf8E4M3FN
+    %4 = tt.dot %2, %3, %arg0, inputPrecision = tf32 : tensor<128x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.return %4 : tensor<128x256xf32, #blocked>
+  }
+}
+
+
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
```

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 32 additions & 0 deletions

```diff
@@ -257,10 +257,16 @@ enum class TensorCoreType : uint8_t {
   FP32_BF16_BF16_FP32,
   FP32_TF32_TF32_FP32,
   FP16_FP16_FP16_FP16,
+  // fp32 accumulator, fp8 operand
   FP32_FP8E5M2_FP8E5M2_FP32,
   FP32_FP8E5M2_FP8E4M3FN_FP32,
   FP32_FP8E4M3FN_FP8E5M2_FP32,
   FP32_FP8E4M3FN_FP8E4M3FN_FP32,
+  // fp16 accumulator, fp8 operand
+  FP16_FP8E5M2_FP8E5M2_FP16,
+  FP16_FP8E5M2_FP8E4M3FN_FP16,
+  FP16_FP8E4M3FN_FP8E5M2_FP16,
+  FP16_FP8E4M3FN_FP8E4M3FN_FP16,
   // integer tensor core instr
   INT32_INT1_INT1_INT32, // Not implemented
   INT32_INT4_INT4_INT32, // Not implemented
@@ -298,6 +304,11 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
   case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
   case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
     return fp32x4Ty;
+  case TensorCoreType::FP16_FP8E5M2_FP8E5M2_FP16:
+  case TensorCoreType::FP16_FP8E5M2_FP8E4M3FN_FP16:
+  case TensorCoreType::FP16_FP8E4M3FN_FP8E5M2_FP16:
+  case TensorCoreType::FP16_FP8E4M3FN_FP8E4M3FN_FP16:
+    return fp16x2Pack2Ty;
   case TensorCoreType::INT32_INT8_INT8_INT32:
     return i32x4Ty;
   case TensorCoreType::FP64_FP64_FP64_FP64:
@@ -341,6 +352,18 @@ static TensorCoreType getMmaType(triton::DotOp op) {
   } else if (dTy.getElementType().isF16()) {
     if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
       return TensorCoreType::FP16_FP16_FP16_FP16;
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
+      return TensorCoreType::FP16_FP8E5M2_FP8E5M2_FP16;
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
+      return TensorCoreType::FP16_FP8E5M2_FP8E4M3FN_FP16;
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
+      return TensorCoreType::FP16_FP8E4M3FN_FP8E5M2_FP16;
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
+      return TensorCoreType::FP16_FP8E4M3FN_FP8E4M3FN_FP16;
   } else if (dTy.getElementType().isF64()) {
     if (aTy.getElementType().isF64() && bTy.getElementType().isF64())
       return TensorCoreType::FP64_FP64_FP64_FP64;
@@ -387,6 +410,15 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
     {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
      "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
 
+    {TensorCoreType::FP16_FP8E5M2_FP8E5M2_FP16,
+     "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16"},
+    {TensorCoreType::FP16_FP8E5M2_FP8E4M3FN_FP16,
+     "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16"},
+    {TensorCoreType::FP16_FP8E4M3FN_FP8E5M2_FP16,
+     "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16"},
+    {TensorCoreType::FP16_FP8E4M3FN_FP8E4M3FN_FP16,
+     "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16"},
+
     {TensorCoreType::FP64_FP64_FP64_FP64,
      "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64"},
 };
```
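The instruction strings added to `mmaInstrPtxAmpere` all follow the same PTX naming scheme: shape `m16n8k32`, A row-major / B column-major, then the D, A, B, C element types in that order. Below is a small illustrative C++ sketch of that convention; `mmaSyncName` is a made-up helper, not code from the patch.

```cpp
#include <iostream>
#include <string>

// Illustrative only: assembles an MMAv2 PTX mnemonic from its parts, matching
// the entries added to mmaInstrPtxAmpere above. The type suffix order is
// <D type>.<A type>.<B type>.<C type>.
static std::string mmaSyncName(const std::string &dType,
                               const std::string &aType,
                               const std::string &bType,
                               const std::string &cType) {
  return "mma.sync.aligned.m16n8k32.row.col." + dType + "." + aType + "." +
         bType + "." + cType;
}

int main() {
  // fp16 accumulator with mixed fp8 operands, as in the new table entries.
  std::cout << mmaSyncName("f16", "e5m2", "e4m3", "f16") << "\n";
  // Prints: mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16
}
```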
