Commit 53e6e9e

[BACKEND] LL for ldmatrix part3 - ldmatrix.x2/x1 for small tiles (#5703)
1 parent b1301d6 commit 53e6e9e

File tree (3 files changed: +56 −104 lines changed)

  lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
  test/Conversion/tritongpu_to_llvm.mlir
  third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 25 additions & 12 deletions
@@ -1101,6 +1101,7 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
   auto rank = shape.size();
   auto opIdx = dot.getOpIdx();
   int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
+  int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1;
 
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
@@ -1117,8 +1118,11 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     auto reg = 1 << logReg;
     basesReg.push_back({0, reg});
   }
-  std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
-  int numTileCols;
+  std::vector<std::vector<int>> basesLane = {
+      {1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}};
+  bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth;
+  bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth;
+  bool nonKX2 = shape[nonKDim] > 8;
   // Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
   // efficiently. opIdx=0 and opIdx=1 are handled differently.
   if (opIdx == 0) {
@@ -1131,13 +1135,16 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     if (needTrans) {
       assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
                                    "supported in the transposed mode");
-      basesLane.push_back({0, 8});
-      basesLane.push_back({8, 0});
+      if (nonKX2)
+        basesLane[3] = {0, 8};
+      if (kX2)
+        basesLane[4] = {8 * 16 / elemBitWidth, 0};
     } else {
-      basesLane.push_back({8, 0});
-      basesLane.push_back({0, 8 * 16 / elemBitWidth});
+      if (nonKX2)
+        basesLane[3] = {8, 0};
+      if (kX2)
+        basesLane[4] = {0, 8 * 16 / elemBitWidth};
     }
-    numTileCols = 16 * 16 / elemBitWidth;
   } else {
     // The matrix elements of thread 0 are distributed in the following pattern
     // (fp16):
@@ -1147,14 +1154,20 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     if (needTrans) {
       assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
                                    "supported in the transposed mode");
-      basesLane.push_back({8, 0});
-      basesLane.push_back({16, 0});
+      if (kX2)
+        basesLane[3] = {8, 0};
+      if (kX4)
+        basesLane[4] = {16, 0};
     } else {
-      basesLane.push_back({0, 8 * 16 / elemBitWidth});
-      basesLane.push_back({0, 16 * 16 / elemBitWidth});
+      if (kX2)
+        basesLane[3] = {0, 8 * 16 / elemBitWidth};
+      if (kX4)
+        basesLane[4] = {0, 16 * 16 / elemBitWidth};
     }
-    numTileCols = 32 * 16 / elemBitWidth;
   }
+  int numTileCols =
+      (8 * 16 / elemBitWidth)
+      << (static_cast<int>(kX2) + static_cast<int>(kX4 && opIdx == 1));
   // Expand the `register` dimension so the size of columns matches `K`.
   auto layout =
       LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
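
For reference, a minimal standalone sketch (not part of the commit) of the tile-sizing arithmetic introduced in this hunk: `numTileCols` below mirrors the expression from the diff, while the function signature and the example shapes in `main` are illustrative assumptions chosen to exercise the ldmatrix.x2/x1 small-tile cases.

#include <cstdio>
#include <vector>

// Reimplements only the arithmetic from the hunk above; names are not Triton APIs.
int numTileCols(const std::vector<int> &shape, int elemBitWidth, int opIdx) {
  int rank = static_cast<int>(shape.size());
  int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
  // One 8x(8x16-bit) ldmatrix sub-tile covers 8*16/elemBitWidth elements along
  // K; kX2/kX4 record whether K needs a second or a fourth sub-tile.
  bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth;
  bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth;
  return (8 * 16 / elemBitWidth)
         << (static_cast<int>(kX2) + static_cast<int>(kX4 && opIdx == 1));
}

int main() {
  // fp16 B operand (opIdx=1, shape = {K, N}) with K = 16: one doubling -> 16 cols.
  std::printf("%d\n", numTileCols({16, 16}, 16, 1)); // 16
  // fp16 B operand with K = 64: kX2 and kX4 -> 32 cols, matching the previous
  // fixed opIdx=1 value for fp16.
  std::printf("%d\n", numTileCols({64, 16}, 16, 1)); // 32
  // fp8 A operand (opIdx=0, shape = {M, K}) with K = 16: a single sub-tile
  // already spans 16 elements -> 16 cols.
  std::printf("%d\n", numTileCols({16, 16}, 8, 0)); // 16
}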

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 10 additions & 5 deletions
@@ -862,8 +862,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
     // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -892,8 +893,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
     // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -974,7 +976,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
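
A back-of-the-envelope check (illustrative only, not Triton code) of why the updated CHECK lines expect a 4-wide i32 struct for the A operand and a 2-wide struct for the B operand. It assumes each ldmatrix address covers the full maxVecElems = 8*16/bitwidth elements, which holds for these 16x16 tests; `expectedNumI32` and the driver are hypothetical names.

#include <cassert>

int expectedNumI32(int bitwidth, int opIdx, int kSize) {
  int maxVecElems = 8 * 16 / bitwidth;      // elements per ldmatrix address
  int numI32 = maxVecElems * bitwidth / 32; // 4 i32 regs -> ldmatrix.x4
  // opIdx=1 tiles smaller than 32x(16 bits) along K drop to ldmatrix.x2,
  // halving the number of returned i32 registers.
  bool shift = (opIdx == 1) && kSize < 32 * 16 / bitwidth;
  return numI32 >> shift;
}

int main() {
  assert(expectedNumI32(16, 0, 16) == 4); // fp16 A: struct<(i32, i32, i32, i32)>
  assert(expectedNumI32(16, 1, 16) == 2); // fp16 B: struct<(i32, i32)>
  assert(expectedNumI32(8, 0, 16) == 4);  // fp8 A:  struct<(i32, i32, i32, i32)>
  assert(expectedNumI32(8, 1, 16) == 2);  // fp8 B:  struct<(i32, i32)>
}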

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 21 additions & 87 deletions
@@ -52,98 +52,28 @@ struct LocalLoadOpConversion
       auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
       auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
       auto needTrans = kOrder != sharedEnc.getOrder()[0];
-      // Limitation 1: Cannot use ldmatrix if we need to transpose a non-fp16
-      // matrix
-      // Limitation 2: If kWidth is greater than the vector width of the dot
-      // operands of MMA, we don't use ldmatrix
-      // Limitation 3 [TODO: remove]: Shared memory with leading offset is not
-      // supported yet
-      auto canUseLdmatrixLegacy =
+      // Limitation 1 [TODO: remove]: Check LL bases to verify register and
+      // address alignment
+      auto canUseLdmatrix =
           (kWidth == vecWidth) && (!sharedEnc.getHasLeadingOffset());
-      if (mmaEnc.isHopper()) {
-        // Limitation 4 [TODO: remove]:
-        // I think we should be able to remove this condition, but it's here
-        // as the legacy ldmatrix path does not support it
-        canUseLdmatrixLegacy &= srcTy.getElementTypeBitWidth() * kWidth == 32 &&
-                                dotEnc.getOpIdx() == 0;
-      }
-      // Limitation 5: If we perform swizzling, it must be done within a single
-      // ldmatrix tile
-      auto maxPhase = sharedEnc.getMaxPhase();
-      auto perPhase = sharedEnc.getPerPhase();
-      auto vecSize = sharedEnc.getVec();
-      canUseLdmatrixLegacy &=
-          (maxPhase == 1) ||
-          ((maxPhase / perPhase <= 8) && (vecSize * bitwidth >= 8 * 16));
+      canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) ||
+                        (sharedEnc.getVec() * bitwidth >= 8 * 16);
       auto shape = srcTy.getShape();
-      auto allocShape = srcTy.getAllocShape();
-      // Limitation 6 [TODO: remove]: Only support 2d matrices now but we should
+      // Limitation 2 [TODO: remove]: Only support 2d matrices now but we should
       // be able to support 3D minor changes
-      auto canUseLdmatrixLL = (bitwidth <= 16 || (!needTrans)) &&
-                              shape.size() <= 2 && canUseLdmatrixLegacy;
-      canUseLdmatrixLegacy &=
-          (bitwidth == 16 || (!needTrans)) && shape.size() <= 2;
-      if (dotEnc.getOpIdx() == 0) {
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (16 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      } else {
-        // Limitation 8 [TODO: remove]: Due to the use of ldmatrix.x4, we need
-        // to read 4 tiles. For opIdx=1, a single warp load four consecutive
-        // tiles along the K dimension, so the minimum K size is 4 * 8 = 32.
-        // The legacy path doesn't have this limitation because it reads
-        // duplicated elements from shared memory and throw them away.
-        // It might be better to use ldmatrix.x2 in such a case instead of
-        // abandoning elements.
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (32 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      }
-      // Limitation 9 [TODO: remove]:
-      // If we remove this one, ldmatrix will IMA. It can probably be relaxed
-      // though. Remove this constraint after all other limitations have been
-      // resolved
-      canUseLdmatrixLegacy &=
-          srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
-      if (canUseLdmatrixLL) {
+      canUseLdmatrix &= (bitwidth <= 16 || !needTrans) && shape.size() <= 2;
+      // Limitation 3: Minimum tile size (8)x(8x16bits)
+      canUseLdmatrix &=
+          shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8;
+      if (canUseLdmatrix) {
         return lowerSharedToDotOperandLL(op, adaptor, getTypeConverter(),
                                          rewriter);
-      } else if (canUseLdmatrixLegacy) {
-        return lowerSharedToDotOperandLegacy(op, adaptor, getTypeConverter(),
-                                             rewriter);
       }
     }
     return failure();
   }
 
 private:
-  LogicalResult
-  lowerSharedToDotOperandLegacy(triton::gpu::LocalLoadOp op,
-                                triton::gpu::LocalLoadOpAdaptor adaptor,
-                                const LLVMTypeConverter *typeConverter,
-                                ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto src = op.getSrc();
-    auto dstLayout = cast<DotOperandEncodingAttr>(op.getType().getEncoding());
-    auto mmaLayout = cast<NvidiaMmaEncodingAttr>(dstLayout.getParent());
-    auto llvmElemTy =
-        typeConverter->convertType(src.getType().getElementType());
-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
-                                                         llvmElemTy, rewriter);
-    Value res;
-    if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
-      if (mmaLayout.isHopper())
-        assert(dstLayout.getOpIdx() == 0 &&
-               "Operand $b in MMAv3 can only be in shared memory");
-
-      res = SharedToDotOperandMMAv2OrV3::convertLayout(
-          dstLayout.getOpIdx(), rewriter, loc, src, dstLayout, smemObj,
-          typeConverter, getThreadId(rewriter, loc));
-    } else {
-      llvm_unreachable("Unsupported mma layout found");
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-
   LogicalResult
   lowerSharedToDotOperandLL(triton::gpu::LocalLoadOp op,
                             triton::gpu::LocalLoadOpAdaptor adaptor,
@@ -158,6 +88,7 @@ struct LocalLoadOpConversion
     auto shape = dstTy.getShape();
     auto rank = dstTy.getRank();
     auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
+    auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
     auto needTrans = kOrder != sharedEnc.getOrder()[0];
 
     auto llvmElemTy = typeConverter->convertType(dstTy.getElementType());
@@ -169,22 +100,25 @@ struct LocalLoadOpConversion
 
     // Emit ldmatrix load operations for values packed in i32s
     SmallVector<Value> elemsI32;
+    // Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for
+    // opIdx=1 is 16x8. Therefore, we use ldmatrix.x2 instead of
+    // ldmatrix.x4 in this case.
+    auto shift = dotEnc.getOpIdx() == 1 && shape[kOrder] < (32 * 16 / bitwidth);
    auto maxVecElems = 8 * 16 / bitwidth;
     bool valid = emitTransferBetweenRegistersAndShared(
         ldmatrixLayout, srcTy, llvmElemTy,
         /*maxVecElems=*/maxVecElems, smemObj, loc, rewriter, targetInfo,
         [&](VectorType vecTy, Value vecAddr) {
           auto numElems = vecTy.getNumElements();
-          auto numElemsI32 = numElems * bitwidth / 32;
+          auto numElemsI32 = (numElems * bitwidth / 32) >> shift;
           auto matTy = LLVM::LLVMStructType::getLiteral(
               ctx, SmallVector<Type>(numElemsI32, i32_ty));
           auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(
               loc, matTy, vecAddr, /*needTrans=*/needTrans);
-          auto resV4 = ldMatrixOp.getResult();
-          elemsI32.push_back(extract_val(i32_ty, resV4, 0));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 1));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 2));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 3));
+          auto res = ldMatrixOp.getResult();
+          for (auto i = 0; i < numElemsI32; ++i) {
+            elemsI32.push_back(extract_val(i32_ty, res, i));
+          }
         });
     assert(valid && "Failed to emit ldmatrix load operations");
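
A sketch (illustrative, not the Triton API) of the slimmed-down eligibility check above, with the encoding and type queries replaced by plain parameters; the `LdmatrixQuery` struct and `canUseLdmatrix` free function are hypothetical stand-ins for the values the real pattern reads from `srcTy`, `sharedEnc`, and `dotEnc`.

#include <cstddef>

struct LdmatrixQuery {
  int kWidth, vecWidth, bitwidth;
  bool hasLeadingOffset, needTrans;
  int maxPhase, vec;   // swizzling parameters of the shared encoding
  size_t rank;         // tensor rank
  int kSize, nonKSize; // extents along the K and non-K dimensions
};

bool canUseLdmatrix(const LdmatrixQuery &q) {
  // Limitation 1 [TODO: remove]: register/address alignment is not yet
  // verified through the LL bases, so require kWidth == vecWidth and no
  // leading offset.
  bool ok = (q.kWidth == q.vecWidth) && !q.hasLeadingOffset;
  // Swizzling, if present, must not split an ldmatrix row: the swizzle vector
  // has to span at least 8x16 bits.
  ok &= (q.maxPhase == 1) || (q.vec * q.bitwidth >= 8 * 16);
  // Limitation 2 [TODO: remove]: 2-D only; transposed loads only for <=16-bit.
  ok &= (q.bitwidth <= 16 || !q.needTrans) && q.rank <= 2;
  // Limitation 3: minimum tile of 8 x (8x16 bits).
  ok &= q.kSize >= (8 * 16 / q.bitwidth) && q.nonKSize >= 8;
  return ok;
}

Here kSize and nonKSize correspond to shape[kOrder] and shape[nonKOrder] in the pattern; everything else maps one-to-one onto the queries in the diff.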
