Commit c80eef1

[BACKEND] Fix promoteOperand behavior in AccelerateMatmul for SM < 80 (#7158)
Dot op using MMA for compute capability < 80 has been deprecated. It falls back to the FMA path. In this path, `promoteOperand` used `triton::FpToFpOp` unconditionally, which supports `F8 <-> FP16, BF16, FP32, FP64` conversions. This change introduces an `ElementType` check in `promoteOperand`: if the operand’s element type is **not** FP8, it uses `arith::ExtFOp` instead of `triton::FpToFpOp`.
1 parent 762ace9 commit c80eef1
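
For quick reference, the patched promotion logic (shown in full in the diff below) boils down to the following sketch. It is a simplified restatement of the patch, not the verbatim source; `type::isFloat8` is the helper made available by the newly included `triton/Conversion/MLIRTypes.h`:

```cpp
// Sketch of promoteOperand after this change (simplified from the diff below).
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                            Type promotedType) {
  auto tensorTy = cast<RankedTensorType>(operand.getType());
  // Same shape and encoding, but with the promoted element type.
  Type tensorPromotedType = tensorTy.cloneWith(std::nullopt, promotedType);
  if (type::isFloat8(tensorTy.getElementType())) {
    // FP8 operands still go through tt.fp_to_fp, which handles
    // F8 <-> FP16/BF16/FP32/FP64 conversions.
    return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
  }
  // Non-FP8 operands (e.g. f16 promoted to f32 on the FMA fallback path)
  // use a plain floating-point extension instead.
  return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
}
```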

2 files changed: +24 -1 lines changed

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 7 additions & 1 deletion
```diff
@@ -6,6 +6,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Conversion/MLIRTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
@@ -757,7 +758,12 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                             Type promotedType) {
   Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
                                 .cloneWith(std::nullopt, promotedType);
-  return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
+  Type operandElType =
+      cast<RankedTensorType>(operand.getType()).getElementType();
+  if (type::isFloat8(operandElType)) {
+    return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
+  }
+  return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
 }
 
 // promote operands of dot op if the existing combination is not natively
```

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 17 additions & 0 deletions
```diff
@@ -549,3 +549,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:75", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: dot_fall_back_fma_before_ampere
+  tt.func public @dot_fall_back_fma_before_ampere(%arg0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+    // CHECK: %[[EXT0:.*]] = arith.extf %arg0
+    // CHECK: %[[EXT1:.*]] = arith.extf %arg1
+    // CHECK: %[[DOT:.*]] = tt.dot %[[EXT0]], %[[EXT1]]
+    %0 = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    // CHECK: tt.store %arg2, %[[DOT]]
+    tt.store %arg2, %0 : tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
```
