
Commit ff63ee2

[NVIDIA] Add native FP4 scaled_dot for SM120 (#8494)
### Summary

Implement native FP4 scaled matmul support for SM120, replacing the previous decomposition fallback. Supported combinations:

- mxfp4 x mxfp4
- nvfp4 x nvfp4 (the scale must be e4m3 and the scale group size is 16)

### Benchmark

E2E vLLM benchmark: Llama3-8B-Instruct, in_len=1024, out_len=1024, batch_size=128, on an RTX 5090 (thanks to @mobicham, who ran this benchmark):

```
Current main branch: mxfp4 x mxfp4: 61 sec
This PR:             mxfp4 x mxfp4: 33 sec
                     nvfp4 x nvfp4: 34.5 sec
```
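For context on the two formats above, a minimal illustrative sketch (plain NumPy, not code from this PR) of what block-scaled FP4 dequantization means numerically: mxfp4 pairs e2m1 values with power-of-two ue8m0 scales over groups of 32, while nvfp4 pairs them with e4m3 scales over groups of 16, as stated in the summary.

```python
import numpy as np

# Decoded e2m1 (fp4) magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6}, plus a sign bit.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequant_row(fp4_vals, scales, group_size):
    """Multiply each group of `group_size` fp4 values by its decoded scale."""
    out = np.empty(fp4_vals.shape, dtype=np.float32)
    for g, s in enumerate(scales):
        lo, hi = g * group_size, (g + 1) * group_size
        out[lo:hi] = fp4_vals[lo:hi] * s
    return out

K = 64
vals = np.random.default_rng(0).choice(E2M1_VALUES, size=K)

# mxfp4: one ue8m0 (power-of-two) scale per 32 values.
mx = dequant_row(vals, scales=2.0 ** np.array([-1.0, 2.0]), group_size=32)
# nvfp4: one e4m3 scale per 16 values (per this PR: e4m3 scales, group size 16).
nv = dequant_row(vals, scales=np.array([0.875, 1.5, 2.0, 0.625]), group_size=16)
```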
1 parent b3e233e commit ff63ee2

File tree

5 files changed: +62 −32 lines

- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
- lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
- python/test/unit/language/test_matmul.py
- test/TritonGPU/accelerate-matmul.mlir
- third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp


lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 6 deletions

```diff
@@ -1496,15 +1496,9 @@ LinearLayout chooseScaledWmmaScaleLayout(
 
 // PTX ISA - Warp-level MMA Block Scaling
 // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
-//
 // This function generates layouts for scale tensors used in scaled dot
 // operations.
-//
-// Supported .kind x scale_vec_size:
-//   mxf8f6f4 with UE8M0 scales -> .scale_vec::1X
-//
 // Implementation notes:
-// - We support only scale_vec::1X for now.
 // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
 //   0)
 // - We choose a fixed byte selector for A (byte-id-a = 0) and B (byte-id-b =
```

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 10 deletions

```diff
@@ -672,15 +672,6 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
     if (numCTAs != 1) {
      return failure();
     }
-
-    // TODO: support mxfp4 variants.
-    if (!((dotOp.getAElemType() == ScaleDotElemType::E5M2 ||
-           dotOp.getAElemType() == ScaleDotElemType::E4M3) &&
-          (dotOp.getBElemType() == ScaleDotElemType::E5M2 ||
-           dotOp.getBElemType() == ScaleDotElemType::E4M3))) {
-      return rewriter.notifyMatchFailure(dotOp, "only E5M2/E4M3 is supported");
-    }
-
     // Skip if any scale is missing. This pattern requires both scales.
     if (!dotOp.getAScale() || !dotOp.getBScale())
       return failure();
@@ -759,7 +750,6 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
     };
 
     const auto mmaWarps = mmaResult.mmaEnc.getWarpsPerCTA(); // [wM, wN]
-
     // Convert scales to Linear layout
     auto convertScale = [&](Value scale, int opIdx) -> Value {
       auto ty = cast<RankedTensorType>(scale.getType());
```

python/test/unit/language/test_matmul.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -1031,8 +1031,10 @@ def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, with_a_sc
     if is_cuda():
         if scale_type == "float8_e4m3fn" and not pack_along_k:
             pytest.skip("Packing along K is required for float8_e4m3fn")
-        if torch.cuda.get_device_capability()[0] != 10:
-            pytest.skip("Requires compute capability == 10")
+        if torch.cuda.get_device_capability()[0] != 10 and torch.cuda.get_device_capability()[0] != 12:
+            pytest.skip("Requires compute capability == 10 or 12")
+        if torch.cuda.get_device_capability()[0] == 12 and pack_along_k is False:
+            pytest.skip("Packing along M, N is not supported on SM120")
         if not (with_a_scale and with_b_scale):
             pytest.skip("None aScale/bScale is only tested on AMD backend for now")
     elif is_hip():
```

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 7 additions & 10 deletions

```diff
@@ -699,21 +699,18 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num
 
 // -----
 
-// Verify that for SM_120 with FP4 inputs, tt.dot_scaled is decomposed into:
-// 1. ttg.fp4_to_fp for unpacking FP4 values
-// 2. Scale application with arith.mulf
-// 3. Regular tt.dot operation with MMA encoding
+// Verify that for SM_120 with FP4 inputs, tt.dot_scaled is preserved and
+// scales are converted to linear layout for hardware acceleration.
 
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked2_k = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
 
 module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-LABEL: @sm120_dot_scaled_fp4_fallback
-  // CHECK-NOT: tt.dot_scaled
-  // CHECK: ttg.fp4_to_fp
-  // CHECK: tt.dot
-  // CHECK: #mma
-  tt.func public @sm120_dot_scaled_fp4_fallback(
+  // CHECK-LABEL: @sm120_dot_scaled_fp4_native
+  // CHECK-DAG: tt.dot_scaled
+  // CHECK-DAG: #linear
+  // CHECK-DAG: #linear1
+  tt.func public @sm120_dot_scaled_fp4_native(
     %a: tensor<128x32xi8, #blocked2_k>,
     %scale_a: tensor<128x2xi8, #blocked2>,
     %b: tensor<32x128xi8, #blocked2>,
```

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 51 additions & 4 deletions

```diff
@@ -296,6 +296,9 @@ enum class TensorCoreType : uint8_t {
   FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X,
   FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X,
   //
+  FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X,
+  FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X,
+  //
   NOT_APPLICABLE,
 };
 
@@ -339,6 +342,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
   case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X:
   case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X:
   case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X:
+  case TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X:
+  case TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X:
     return fp32x4Ty;
   default:
     llvm::report_fatal_error("Unsupported mma type found");
@@ -367,6 +372,15 @@ static TensorCoreType getMmaTypeDotScaled(DotScaledOp op, RankedTensorType aTy,
         llvm::isa<Float8E4M3FNType>(bTy.getElementType())) {
       return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X;
     }
+    if (op.getBElemType() == ScaleDotElemType::E2M1 &&
+        op.getAElemType() == ScaleDotElemType::E2M1) {
+      if (isa<mlir::Float8E4M3FNType>(
+              op.getBScale().getType().getElementType())) {
+        return TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X;
+      } else {
+        return TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X;
+      }
+    }
   }
   return TensorCoreType::NOT_APPLICABLE;
 }
@@ -493,6 +507,14 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxScaled = {
      "mma.sync.aligned.m16n8k32.row.col."
      "kind::mxf8f6f4.block_scale.scale_vec::"
      "1X.f32.e4m3.e4m3.f32.ue8m0"},
+    {TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X,
+     "mma.sync.aligned.m16n8k64.row.col."
+     "kind::mxf4nvf4.block_scale.scale_vec::"
+     "2X.f32.e2m1.e2m1.f32.ue8m0"},
+    {TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X,
+     "mma.sync.aligned.m16n8k64.row.col."
+     "kind::mxf4nvf4.block_scale.scale_vec::"
+     "4X.f32.e2m1.e2m1.f32.ue4m3"},
 };
 
 static void callMmaTuringInt8(PTXBuilder &builder, int b,
```
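As a hedged Python summary (illustrative only, not repo code) of the instruction selection added above: both FP4 variants use the same kind::mxf4nvf4 m16n8k64 MMA shape, and the scale element type decides between the mxfp4 and nvfp4 encodings.

```python
# Illustrative restatement of the new getMmaTypeDotScaled / mmaInstrPtxScaled
# entries for e2m1 x e2m1 operands, keyed on the scale element type.
def fp4_mma_instruction(scale_dtype: str) -> str:
    if scale_dtype == "e4m3":  # nvfp4 path -> scale_vec::4X, ue4m3 scales
        return ("mma.sync.aligned.m16n8k64.row.col."
                "kind::mxf4nvf4.block_scale.scale_vec::4X."
                "f32.e2m1.e2m1.f32.ue4m3")
    # mxfp4 path (ue8m0 scales) -> scale_vec::2X
    return ("mma.sync.aligned.m16n8k64.row.col."
            "kind::mxf4nvf4.block_scale.scale_vec::2X."
            "f32.e2m1.e2m1.f32.ue8m0")

print(fp4_mma_instruction("e4m3"))
```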
```diff
@@ -890,13 +912,12 @@ LogicalResult convertMMADotScaled(triton::DotScaledOp op,
   TensorCoreType mmaType =
       getMmaTypeDotScaled(op, aTensorTy, bTensorTy, dTensorTy);
 
-  NumRegisters numRegisters = {2, 1, 2};
-
   SmallVector<Value> unpackedAScale =
       unpackLLElements(op.getLoc(), adaptor.getAScale(), rewriter);
   SmallVector<Value> unpackedBScale =
       unpackLLElements(op.getLoc(), adaptor.getBScale(), rewriter);
 
+  NumRegisters numRegisters = {2, 1, 2};
   EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k,
                              mlir::triton::PTXInstr &mma, unsigned numMmaRets,
                              unsigned colsPerThread, unsigned batchOffset,
@@ -906,8 +927,34 @@ LogicalResult convertMMADotScaled(triton::DotScaledOp op,
     auto tb = TritonLLVMOpBuilder(op.getLoc(), rewriter);
     auto i32 = IntegerType::get(op->getContext(), 32);
 
-    Value aScaleValue = tb.zext(i32, unpackedAScale[m * repK + k]);
-    Value bScaleValue = tb.zext(i32, unpackedBScale[n * repK + k]);
+    auto packElements = [&](ArrayRef<Value> bytes, int loc,
+                            int numBytes) -> Value {
+      Value packed = tb.zext(i32, bytes[loc]);
+      for (int i = 1; i < numBytes; ++i) {
+        Value byte = tb.zext(i32, bytes[loc + i]);
+        Value shifted = tb.shl(byte, tb.i32_val(i * 8));
+        packed = tb.or_(packed, shifted);
+      }
+      return packed;
+    };
+
+    int scaleVecMode;
+    if (mmaInstrPtxScaled.at(mmaType).find("1X") != std::string::npos) {
+      scaleVecMode = 1;
+    } else if (mmaType ==
+               TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X) {
+      scaleVecMode = 2;
+    } else if (mmaType == TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X) {
+      scaleVecMode = 4;
+    } else {
+      llvm_unreachable("Unsupported scale vector mode!");
+    }
+    Value aScaleValue =
+        packElements(unpackedAScale, m * repK * scaleVecMode + k * scaleVecMode,
+                     scaleVecMode);
+    Value bScaleValue =
+        packElements(unpackedBScale, n * repK * scaleVecMode + k * scaleVecMode,
+                     scaleVecMode);
 
     BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k};
     callMmaScaled(builder, b, base, mma, numMmaRets, colsPerThread, aTable,
```
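To make the index arithmetic above concrete, here is a pure-Python model of the packing scheme (illustrative only; `pack_scales` and `rep_k` are simplified stand-ins for the `packElements` lambda and `repK`): for scale_vec::2X / ::4X, consecutive per-tile scale bytes are packed little-endian into the single 32-bit scale operand handed to the MMA.

```python
# Model of the scale-byte packing added in convertMMADotScaled.
def pack_scales(scale_bytes, m, k, rep_k, scale_vec_mode):
    # Same index math as the C++ code: scale_vec_mode bytes per (m, k) tile.
    base = m * rep_k * scale_vec_mode + k * scale_vec_mode
    packed = 0
    for i in range(scale_vec_mode):
        packed |= scale_bytes[base + i] << (8 * i)  # little-endian into an i32
    return packed

# Example: nvfp4 (scale_vec::4X) packs 4 consecutive e4m3 scale bytes.
scales = list(range(32))  # stand-in for the unpacked scale bytes
assert pack_scales(scales, m=1, k=0, rep_k=2, scale_vec_mode=4) == (
    8 | (9 << 8) | (10 << 16) | (11 << 24))
```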
