
Commit 629fd50 (parent 76ed95b)
[WARNINGS] Emit warning for WGMMA fp8 dot when transposition prevents pipelining (#6875)
**TL;DR**: For fp8 WGMMA matmuls, if the input tensors are not in a specific transposed layout in global memory (row-major A, col-major B), pipelining will be disabled. Emit a warning for these cases.

If you run an fp8 matmul (e.g. 03-matrix-multiplication) with the B matrix in row-major format (e.g. https://gist.github.com/davidberard98/21fcee4a46192a1a756a458dfc3669fe) and use MLIR_ENABLE_DIAGNOSTICS=warnings, a warning like this one will be emitted:

```
/home/dberard/fbcode/scripts/dberard/triton/fp8_mm.py:171:35: warning: Warning: Forcing a different order [0, 1] on SMEM than the register order for the operand 1. Registers will be transposed before SMEM store and the pipelined load for this operand will be disabled, so poor performance is expected.
    accumulator = tl.dot(a, b, accumulator)
```

Since this is a user-facing restriction that has significant implications for the performance of fp8 matmuls, I think it makes sense to make this a warning.

Note: this warning already exists for MMAv5; this PR just plumbs the required info into the getSharedMemoryMMAOperand function so that the diagnostic can be emitted:
https://github.com/triton-lang/triton/blob/7dc549208aa3ce30612fe884bc4723f95f4b40b1/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#L188-L195
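For illustration (not part of this commit), here is a minimal host-side sketch of the layouts involved. The tensor names, shapes, and the commented-out `matmul_kernel` launch are hypothetical stand-ins for a tutorial-style fp8 matmul; the point is that the B operand's strides in global memory decide whether the warning fires.

```python
# Sketch only: global-memory layouts that do / do not trigger the new warning for
# fp8 WGMMA dots. The Triton matmul kernel itself (e.g. the 03-matrix-multiplication
# tutorial kernel) is assumed and not shown here.
import os
import torch

# Surface compiler warnings; must be set before the kernel is first compiled.
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = "warnings"

M, K, N = 128, 128, 256
a = torch.randn(M, K, device="cuda").to(torch.float8_e5m2)  # row-major A: fine

# Row-major B: strides (N, 1). The pass forces a different SMEM order, registers are
# transposed before the SMEM store, and the pipelined load for operand 1 is disabled.
b_row_major = torch.randn(K, N, device="cuda").to(torch.float8_e5m2)
print(b_row_major.stride())  # (256, 1)

# Column-major B: same logical (K, N) shape, but strides (1, K). One way to get this
# layout is to allocate the transposed tensor and view it back.
b_col_major = torch.randn(N, K, device="cuda").to(torch.float8_e5m2).t()
print(b_col_major.stride())  # (1, 128)

# A tutorial-style launch would then pass the strides through to the kernel, e.g.:
# matmul_kernel[grid](a, b_col_major, c, M, N, K,
#                     a.stride(0), a.stride(1),
#                     b_col_major.stride(0), b_col_major.stride(1), ...)
```

With `b_row_major`, compiling under MLIR_ENABLE_DIAGNOSTICS=warnings should produce the "Forcing a different order" diagnostic shown above; with `b_col_major` it should not, and the pipelined load stays enabled.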

2 files changed (+30, -5 lines)

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 4 deletions
@@ -188,10 +188,12 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
   if (newOrder != order && op) {
     op->emitWarning("Warning: Forcing a different order [")
         << newOrder[0] << ", " << newOrder[1]
-        << "] on SMEM than the register order for the opreand " << opIdx
+        << "] on SMEM than the register order for the operand " << opIdx
         << ". Registers will be transposed before SMEM store and the pipelined "
            "load for this operand will be disabled, so poor performance is "
-           "expected.";
+           "expected. Recommendation: consider transposing the operand in "
+           "global "
+           "memory to remove the need to transpose the tensor in registers.";
   }
 
   Attribute SharedMemorySpace =
@@ -391,9 +393,14 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
       int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth();
       a = getDotOperand(a, 0, bitwidth);
     } else {
-      a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+      a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose,
+                                    /*isMMAv5Fp4Padded=*/false,
+                                    /*forceTranspose=*/false, dotOp);
     }
-    b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
+    b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose,
+                                  /*isMMAv5Fp4Padded=*/false,
+                                  /*forceTranspose=*/false, dotOp);
+
     newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
         dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
         dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false);

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 19 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s
+// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul -verify-diagnostics=only-expected | FileCheck %s
 
 // CHECK: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
 // CHECK: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
@@ -526,3 +526,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %0 : tensor<128x256xf32, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: hopper_fp8_non_transposed_b
+  tt.func public @hopper_fp8_non_transposed_b(
+      %operand0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+      %operand1: tensor<128x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+      %out_ptrs: tensor<128x256x!tt.ptr<f32>, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+    // CHECK: ttng.warp_group_dot
+    // expected-warning @below {{Forcing a different order}}
+    %64 = tt.dot %operand0, %operand1, %cst, inputPrecision = tf32 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.store %out_ptrs, %64 : tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
