
Commit d0ef9d3

Fix handling of 3D DotOp with M=1 (#8561)
We currently always look at index 1 in the shape to determine if it is an outer product. However, for 3D operands, we actually have to look at index 2. This causes us to incorrectly crash when supplied a 3D operand where the first dimension is 1. Eliminate the isOuter check entirely, since my understanding is that:

1. it is purely defensive, since we will just crash when isOuter is true and mmaLayout is non-null.
2. it does not disqualify all the invalid K values, so it might be confusing.
1 parent 7c32dad commit d0ef9d3
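
For illustration only, here is a minimal standalone C++ sketch of the indexing mistake described in the commit message; it is not Triton code and not the committed fix (which removes the check instead of repairing it). The isOuterProduct helper is hypothetical: it reads K from the last axis of the A operand's shape, which is index 1 for a 2D [M, K] operand but index 2 for a 3D [B, M, K] operand.

// Minimal sketch, not Triton's API: isOuterProduct is a hypothetical helper.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true only when the reduction (K) axis of the A operand is 1,
// independent of rank: K is the last axis of the A operand's shape.
bool isOuterProduct(const std::vector<int64_t> &aShapePerCTA) {
  assert(aShapePerCTA.size() >= 2 && "DotOp A operand must have rank >= 2");
  size_t reduceAxis = aShapePerCTA.size() - 1;
  return aShapePerCTA[reduceAxis] == 1;
}

int main() {
  // 3D A operand with M == 1, shape [32, 1, 32]: the old hardcoded index 1
  // read M and wrongly flagged this as an outer product; K is actually 32.
  std::cout << isOuterProduct({32, 1, 32}) << "\n"; // prints 0
  // A genuine outer product has K == 1, e.g. shape [32, 16, 1].
  std::cout << isOuterProduct({32, 16, 1}) << "\n"; // prints 1
  return 0;
}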


2 files changed: +27 −14 lines


test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 25 additions & 0 deletions
@@ -2645,3 +2645,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// We had a bug where DotOp lowering treated any input where shape[1] == 1 as an
+// outer product and rejected it. This was incorrect in 3D tensors, since
+// the dimension to look at would have been shape[2].
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [32, 1, 1], instrShape = [1, 16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
+#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: batched_dot_3d
+  tt.func public @batched_dot_3d(
+      %arg0: tensor<32x1x32xf16, #dot_operand_a>,
+      %arg1: tensor<32x32x32xf16, #dot_operand_b>
+  ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x1x32xf32, #mma>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %result = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 :
+      tensor<32x1x32xf16, #dot_operand_a> * tensor<32x32x32xf16, #dot_operand_b> -> tensor<32x1x32xf32, #mma>
+    tt.return
+  }
+}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 2 additions & 14 deletions
@@ -60,15 +60,9 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
     Value A = op.getA();
     Value D = op.getResult();
 
-    // Here we assume the DotOp's operands always comes from shared memory.
-    auto AShapePerCTA = getShapePerCTA(A.getType());
-    size_t reduceAxis = 1;
-    unsigned K = AShapePerCTA[reduceAxis];
-    bool isOuter = K == 1;
-
     NvidiaMmaEncodingAttr mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(
         cast<RankedTensorType>(D.getType()).getEncoding());
-    if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
+    if (mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
       if (mmaLayout.getVersionMajor() == 2) {
         bool isHopperF64 =
             computeCapability == 90 &&
@@ -106,14 +100,8 @@ struct WarpGroupDotOpConversion
     Value A = op.getA();
     TypedValue<RankedTensorType> D = op.getResult();
 
-    // Here we assume the DotOp's operands always comes from shared memory.
-    auto AShapePerCTA = getShapePerCTA(A.getType());
-    size_t reduceAxis = 1;
-    unsigned K = AShapePerCTA[reduceAxis];
-    bool isOuter = K == 1;
-
     auto mmaLayout = cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
-    if (!isOuter && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
+    if (supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
       return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
                           getThreadId(rewriter, loc));
     }
