Commit d47e314

[BACKEND] LL for ldmatrix part2 - support fp8, sliced shared memory, and transposed matrices (#5644)
1 parent 4616092 commit d47e314

6 files changed: +199 -135 lines changed


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 0 deletions
@@ -1029,6 +1029,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
 
+[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
+
 SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            triton::gpu::MemDescType srcTy,
                                            Type elemLlvmTy,

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 2 deletions
@@ -243,8 +243,8 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
 
 // The primary goal of this function is to efficiently store 2D tiles of a
 // tensor into shared memory using the `ldmatrix` instruction.
-LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
-                                  Attribute dotEnc, ArrayRef<int64_t> shape);
+LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                  bool needTrans, int32_t elemBitWidth);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
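Note on the new signature: the header no longer takes an MLIRContext or the shared encoding; callers pass the dot-operand encoding together with a transpose flag and the element bit width. A hypothetical call site (the variable names below are placeholders, not taken from this commit) would migrate roughly as follows:

// Before: chooseLdMatrixLayout(ctx, sharedTy.getEncoding(), dotEnc, shape)
// After: the context comes from the encoding itself, and fp8 is selected
// simply by passing elemBitWidth = 8.
LinearLayout ll = chooseLdMatrixLayout(dotEnc, sharedTy.getShape(),
                                       /*needTrans=*/false,
                                       /*elemBitWidth=*/16);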

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 18 additions & 8 deletions
@@ -300,10 +300,9 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
 } // namespace
 
 bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
 
@@ -313,8 +312,6 @@ bool emitTransferBetweenRegistersAndShared(
   StringAttr kWarp = str_attr("warp");
 
   auto shape = sharedTy.getShape();
-  LinearLayout regLayout =
-      triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
   LinearLayout sharedLayout = triton::gpu::toLinearLayout(
       shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
   LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
@@ -360,14 +357,13 @@ bool emitTransferBetweenRegistersAndShared(
   // Thus we use `pseudoinvert` instead of `invert` here for simplicity.
   auto allocShape = sharedTy.getAllocShape();
   LinearLayout invertAllocSharedLayout =
-      triton::gpu::toLinearLayout(allocShape.take_back(registerTy.getRank()),
+      triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
                                   sharedTy.getEncoding(),
                                   elemLlvmTy.getIntOrFloatBitWidth())
           .pseudoinvert();
 
   int numElems = regToSharedLayout.getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  Value zero = i32_val(0);
   SmallVector<Value> ret;
   for (int i = 0; i < numElems / vecElems; i++) {
     auto regId = i32_val(i * vecElems);
@@ -379,6 +375,20 @@ bool emitTransferBetweenRegistersAndShared(
   return true;
 }
 
+bool emitTransferBetweenRegistersAndShared(
+    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
+    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
+    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
+    const TargetInfoBase &target,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+  auto regLayout = triton::gpu::toLinearLayout(
+      registerTy.getShape(), registerTy.getEncoding(),
+      elemLlvmTy.getIntOrFloatBitWidth());
+  return emitTransferBetweenRegistersAndShared(
+      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
+      target, perVectorCallback);
+}
+
 SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            triton::gpu::MemDescType srcTy,
                                            Type elemLlvmTy,
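The original RankedTensorType entry point now simply derives the register layout with toLinearLayout and forwards to the new LinearLayout-based overload, so a lowering that already holds a bespoke register layout can call the new overload directly. A minimal usage sketch, assuming the surrounding values (dstTy, sharedTy, elemLlvmTy, smemObj, loc, rewriter, target) are placeholders already in scope; the lambda body is illustrative only:

LinearLayout regLayout = triton::gpu::toLinearLayout(
    sharedTy.getShape(), dstTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
bool ok = emitTransferBetweenRegistersAndShared(
    regLayout, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
    loc, rewriter, target, [&](VectorType vecTy, Value shmemAddr) {
      // one vectorized shared-memory access per callback invocation
    });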

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 53 additions & 56 deletions
@@ -1093,80 +1093,80 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
-                                                 SharedEncodingAttr shared,
-                                                 DotOperandEncodingAttr dot,
-                                                 ArrayRef<int64_t> shape) {
+LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
+                                     ArrayRef<int64_t> shape, bool needTrans,
+                                     int32_t elemBitWidth) {
+  auto ctx = dot.getContext();
   auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
   auto rank = shape.size();
   auto opIdx = dot.getOpIdx();
-  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
 
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
   StringAttr kBlock = S("block");
-  StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
-  StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
-
-  std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
-  std::vector<std::vector<int>> basesLane;
-  auto numRowsPerTile = 16;
-  auto numColsPerTile = 16;
-  int vecSize = shared.getVec();
-  int perPhase = shared.getPerPhase();
-  int maxPhase = shared.getMaxPhase();
-  auto warpsPerCTA = mma.getWarpsPerCTA();
-  // Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
+  StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1"))
+                                 : (needTrans ? S("dim1") : S("dim0"));
+  StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0"))
+                                 : (needTrans ? S("dim0") : S("dim1"));
+
+  std::vector<std::vector<int>> basesReg;
+  for (int logReg = 0; logReg < llvm::Log2_32(8 * 16 / elemBitWidth);
+       logReg++) {
+    auto reg = 1 << logReg;
+    basesReg.push_back({0, reg});
+  }
+  std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
+  int numTileCols;
+  // Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
   // efficiently. opIdx=0 and opIdx=1 are handled differently.
   if (opIdx == 0) {
-    // The matrix elements of thread 0 are distributed in the following pattern:
+    // The matrix elements of thread 0 are distributed in the following pattern
+    // (fp16):
     //
     //           col0       col8
     //   row0  reg[0-1]   reg[4-5]
     //   row8  reg[2-3]   reg[6-7]
-    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
-      int row = 1 << logRow;
-      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
-    }
-    basesLane.push_back({0, numColsPerTile / 2});
-    // Expand the `register` dimension so the size of columns matches `K`.
-    for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
-         logCol++) {
-      int col = 1 << logCol;
-      basesReg.push_back({0, numColsPerTile * col});
+    if (needTrans) {
+      assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
+                                   "supported in the transposed mode");
+      basesLane.push_back({0, 8});
+      basesLane.push_back({8, 0});
+    } else {
+      basesLane.push_back({8, 0});
+      basesLane.push_back({0, 8 * 16 / elemBitWidth});
     }
+    numTileCols = 16 * 16 / elemBitWidth;
   } else {
-    // The matrix elements of thread 0 are distributed in the following pattern:
+    // The matrix elements of thread 0 are distributed in the following pattern
+    // (fp16):
    //
     //           col0       col8      col16      col24
     //   row0  reg[0-1]   reg[2-3]   reg[4-5]   reg[6-7]
-    // 8x8
-    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
-      int row = 1 << logRow;
-      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
-    }
-    // 8x16
-    basesLane.push_back({0, numColsPerTile / 2});
-    // 8x32
-    basesLane.push_back({0, numColsPerTile});
-    // Expand the `register` dimension so the size of columns matches `K`.
-    for (int logCol = 0;
-         logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
-      int col = 1 << logCol;
-      basesReg.push_back({0, (numColsPerTile * 2) * col});
+    if (needTrans) {
+      assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
+                                   "supported in the transposed mode");
+      basesLane.push_back({8, 0});
+      basesLane.push_back({16, 0});
+    } else {
+      basesLane.push_back({0, 8 * 16 / elemBitWidth});
+      basesLane.push_back({0, 16 * 16 / elemBitWidth});
     }
+    numTileCols = 32 * 16 / elemBitWidth;
   }
-  auto layout = LinearLayout(
-      {{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
+  // Expand the `register` dimension so the size of columns matches `K`.
+  auto layout =
+      LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
+                   {kOuter, kInner}) *
+      LinearLayout::identity1D(shape[kDim] / numTileCols, kReg,
+                               S("dim" + std::to_string(kDim)));
   // Expand the `warp` dimension according to warpsPerCTA.
+  auto warpsPerCTA = mma.getWarpsPerCTA();
   layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
                                         kDim, kWarp)
                 .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
-  auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
-  return ret.transposeOuts({kInner, kOuter})
-      .reshapeOuts(
-          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+  return combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
 }
 
 } // anonymous namespace
@@ -1180,13 +1180,10 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
   return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
 }
 
-LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
-                                  Attribute dotEnc, ArrayRef<int64_t> shape) {
-  auto shared = cast<SharedEncodingAttr>(sharedEnc);
-  auto dot = cast<DotOperandEncodingAttr>(dotEnc);
-  assert(!shared.getHasLeadingOffset() &&
-         "Ldmatrix does not support leading offset yet");
-  return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
+LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                  bool needTrans, int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
 } // namespace mlir::triton::gpu
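The element bit width now drives both the register bases and the tile width: the register dimension spans 8 * 16 / elemBitWidth contiguous elements (128 bits' worth), and numTileCols is 16 * 16 / elemBitWidth for operand A versus 32 * 16 / elemBitWidth for operand B. A standalone sketch of that arithmetic (illustrative only, not part of the commit):

#include <cstdio>

int main() {
  for (int elemBitWidth : {16, 8}) {
    int contigElems = 8 * 16 / elemBitWidth;  // 8 for fp16, 16 for fp8
    int tileColsA = 16 * 16 / elemBitWidth;   // opIdx == 0: 16 or 32 columns
    int tileColsB = 32 * 16 / elemBitWidth;   // opIdx == 1: 32 or 64 columns
    std::printf("%2d-bit: %2d contiguous elems, A tile %2d cols, B tile %2d cols\n",
                elemBitWidth, contigElems, tileColsA, tileColsB);
  }
  return 0;
}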

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 65 additions & 6 deletions
@@ -852,18 +852,77 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot_ldmatrix
+  tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix
+    // CHECK-NOT: nvgpu.ldmatrix
+    %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
+    %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_dot
-  tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+  tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     // CHECK: nvgpu.ldmatrix
     // CHECK: nvgpu.ldmatrix
+    // CHECK-NOT: nvgpu.ldmatrix
+    %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
+    %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot
+  tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -905,7 +964,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 16, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}>
@@ -1206,7 +1265,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
@@ -1255,7 +1314,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // -----
 
 #mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
-#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
@@ -1744,7 +1803,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
