Skip to content

Commit e592cab

Browse files
Merge commit 'd9fd9c59a68dea63f1dcf8e2e20d9eda16589d68'
2 parents 47c150d + d9fd9c5 commit e592cab

File tree

25 files changed

+329
-203
lines changed

25 files changed

+329
-203
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
10171017
const TargetInfoBase &target,
10181018
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
10191019

1020+
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
1021+
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
1022+
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
1023+
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
1024+
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
1025+
10201026
SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
10211027
triton::gpu::MemDescType srcTy,
10221028
Type elemLlvmTy,

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,11 @@ unsigned getNumCTAs(Attribute layout);
178178
// len(shape) == rank.
179179
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
180180

181-
// Return the order that represents that the dot operand is in kMajor
181+
// Return the order that represents that the dot operand is in kContig
182182
// (contiguous in the inner dimension) or it's contiguous on the outer
183183
// dimension.
184184
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
185-
bool kMajor);
185+
bool kContig);
186186

187187
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
188188

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
243243

244244
// The primary goal of this function is to efficiently store 2D tiles of a
245245
// tensor into shared memory using the `ldmatrix` instruction.
246-
LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
247-
Attribute dotEnc, ArrayRef<int64_t> shape);
246+
LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
247+
bool needTrans, int32_t elemBitWidth);
248248
} // namespace mlir::triton::gpu
249249

250250
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global",
258258
}];
259259
}
260260

261-
def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
261+
def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
262262
let summary = "wait until all the inputs are read.";
263263
let arguments = (ins I32Attr:$pendings);
264264
let description = [{

lib/Analysis/AxisInfo.cpp

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "mlir/Analysis/DataFlowFramework.h"
2-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
32
#include "llvm/Support/Debug.h"
43
#include "llvm/Support/raw_ostream.h"
54

@@ -232,13 +231,13 @@ class MakeRangeOpAxisInfoVisitor final
232231
}
233232
};
234233

235-
template <typename OpTy>
236-
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
234+
class ConstantOpAxisInfoVisitor final
235+
: public AxisInfoVisitorImpl<arith::ConstantOp> {
237236
public:
238-
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
237+
using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
239238

240239
AxisInfo
241-
getAxisInfo(OpTy op,
240+
getAxisInfo(arith::ConstantOp op,
242241
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
243242
auto intAttr = dyn_cast<IntegerAttr>(op.getValue());
244243
auto boolAttr = dyn_cast<BoolAttr>(op.getValue());
@@ -323,8 +322,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
323322
const AxisInfo &rhs) override {
324323
if (lhs.getConstantValue().has_value() &&
325324
rhs.getConstantValue().has_value()) {
326-
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
327-
std::is_same_v<OpTy, LLVM::AddOp>) {
325+
if constexpr (std::is_same_v<OpTy, arith::AddIOp>) {
328326
return {lhs.getConstantValue().value() +
329327
rhs.getConstantValue().value()};
330328
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -1013,15 +1011,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10131011
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
10141012
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10151013
CastOpAxisInfoVisitor<triton::BitcastOp>>();
1016-
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
1017-
// when scf.for supports integer induction variables
10181014
visitors.append<MakeRangeOpAxisInfoVisitor>();
1019-
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
1020-
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
1015+
visitors.append<ConstantOpAxisInfoVisitor>();
10211016
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
10221017
AddSubOpAxisInfoVisitor<arith::AddIOp>,
1023-
AddSubOpAxisInfoVisitor<arith::SubIOp>,
1024-
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
1018+
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
10251019
visitors.append<MulIOpAxisInfoVisitor>();
10261020
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
10271021
DivOpAxisInfoVisitor<arith::DivUIOp>>();
@@ -1138,17 +1132,11 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
11381132

11391133
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
11401134
Operation *op = blockArg.getOwner()->getParentOp();
1141-
if (auto fun = dyn_cast<FunctionOpInterface>(op))
1142-
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
1143-
&knownContiguity, &knownDivisibility,
1144-
&knownConstancy);
1145-
// llvm codegen check alignment to generate vector load/store
1146-
// would be nice if this wasn't the case
1147-
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
1135+
if (auto fun = dyn_cast<FunctionOpInterface>(op)) {
11481136
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11491137
&knownContiguity, &knownDivisibility,
11501138
&knownConstancy);
1151-
else if (isa<RegionBranchOpInterface>(op)) {
1139+
} else if (isa<RegionBranchOpInterface>(op)) {
11521140
// scf::ForOp, scf::IfOp, scf::WhileOp
11531141
// Control flow operations are initialized with "unknown" state:
11541142
// the maximum possible divisibility, contiguity, and constancy.

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,9 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
300300
} // namespace
301301

302302
bool emitTransferBetweenRegistersAndShared(
303-
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
304-
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
305-
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
306-
const TargetInfoBase &target,
303+
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
304+
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
305+
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
307306
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
308307
MLIRContext *ctx = rewriter.getContext();
309308

@@ -313,8 +312,6 @@ bool emitTransferBetweenRegistersAndShared(
313312
StringAttr kWarp = str_attr("warp");
314313

315314
auto shape = sharedTy.getShape();
316-
LinearLayout regLayout =
317-
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
318315
LinearLayout sharedLayout = triton::gpu::toLinearLayout(
319316
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
320317
LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
@@ -360,14 +357,13 @@ bool emitTransferBetweenRegistersAndShared(
360357
// Thus we use `pseudoinvert` instead of `invert` here for simplicity.
361358
auto allocShape = sharedTy.getAllocShape();
362359
LinearLayout invertAllocSharedLayout =
363-
triton::gpu::toLinearLayout(allocShape.take_back(registerTy.getRank()),
360+
triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
364361
sharedTy.getEncoding(),
365362
elemLlvmTy.getIntOrFloatBitWidth())
366363
.pseudoinvert();
367364

368365
int numElems = regToSharedLayout.getInDimSize(kRegister);
369366
auto vecTy = vec_ty(elemLlvmTy, vecElems);
370-
Value zero = i32_val(0);
371367
SmallVector<Value> ret;
372368
for (int i = 0; i < numElems / vecElems; i++) {
373369
auto regId = i32_val(i * vecElems);
@@ -379,6 +375,20 @@ bool emitTransferBetweenRegistersAndShared(
379375
return true;
380376
}
381377

378+
bool emitTransferBetweenRegistersAndShared(
379+
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
380+
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
381+
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
382+
const TargetInfoBase &target,
383+
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
384+
auto regLayout = triton::gpu::toLinearLayout(
385+
registerTy.getShape(), registerTy.getEncoding(),
386+
elemLlvmTy.getIntOrFloatBitWidth());
387+
return emitTransferBetweenRegistersAndShared(
388+
regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
389+
target, perVectorCallback);
390+
}
391+
382392
SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
383393
triton::gpu::MemDescType srcTy,
384394
Type elemLlvmTy,

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,15 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
242242
}
243243

244244
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
245-
bool kMajor) {
246-
// kMajor: if true, the matrix is fastest-running on k,
245+
bool kContig) {
246+
// kContig: if true, the matrix is fastest-running on k,
247247
// otherwise it is on m (resp. n)
248248
// opIdx=0: [batch, m, k] if rank == 3 else [m, k]
249249
// opIdx=1: [batch, k, n] if rank == 3 else [k, n]
250250
// batch (if rank == 3) is always the slowest running dimension
251251
assert(rank == 2 || rank == 3);
252252
assert(opIdx == 0 || opIdx == 1);
253-
auto rowMajor = bool(opIdx) != kMajor;
253+
auto rowMajor = bool(opIdx) != kContig;
254254
return getMatrixOrder(rank, rowMajor);
255255
}
256256

@@ -283,7 +283,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
283283
}
284284
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
285285
auto rank = dotLayout.getWarpsPerCTA().size();
286-
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
286+
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kContig*/ true);
287287
}
288288
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
289289
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1002,7 +1002,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
10021002
}
10031003
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
10041004
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1005-
/*kMajor*/ true);
1005+
/*kContig*/ true);
10061006
}
10071007

10081008
LogicalResult DotOperandEncodingAttr::verify(
@@ -2004,7 +2004,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
20042004
SmallVector<unsigned>
20052005
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
20062006
auto rank = getWarpsPerCTA().size();
2007-
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
2007+
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
20082008
}
20092009

20102010
SmallVector<unsigned>
@@ -2072,7 +2072,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
20722072
SmallVector<unsigned>
20732073
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
20742074
auto rank = getWarpsPerCTA().size();
2075-
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
2075+
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
20762076
}
20772077

20782078
SmallVector<unsigned>
@@ -2264,7 +2264,7 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
22642264
SmallVector<unsigned>
22652265
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
22662266
auto rank = getWarpsPerCTA().size();
2267-
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
2267+
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
22682268
}
22692269

22702270
SmallVector<unsigned>

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,80 +1097,80 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
10971097
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
10981098
}
10991099

1100-
LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
1101-
SharedEncodingAttr shared,
1102-
DotOperandEncodingAttr dot,
1103-
ArrayRef<int64_t> shape) {
1100+
LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
1101+
ArrayRef<int64_t> shape, bool needTrans,
1102+
int32_t elemBitWidth) {
1103+
auto ctx = dot.getContext();
11041104
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
11051105
auto rank = shape.size();
11061106
auto opIdx = dot.getOpIdx();
1107-
int kDim = opIdx == 0 ? rank - 1 : rank - 2;
1107+
int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
11081108

11091109
StringAttr kReg = S("register");
11101110
StringAttr kLane = S("lane");
11111111
StringAttr kWarp = S("warp");
11121112
StringAttr kBlock = S("block");
1113-
StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
1114-
StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
1115-
1116-
std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
1117-
std::vector<std::vector<int>> basesLane;
1118-
auto numRowsPerTile = 16;
1119-
auto numColsPerTile = 16;
1120-
int vecSize = shared.getVec();
1121-
int perPhase = shared.getPerPhase();
1122-
int maxPhase = shared.getMaxPhase();
1123-
auto warpsPerCTA = mma.getWarpsPerCTA();
1124-
// Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
1113+
StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1"))
1114+
: (needTrans ? S("dim1") : S("dim0"));
1115+
StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0"))
1116+
: (needTrans ? S("dim0") : S("dim1"));
1117+
1118+
std::vector<std::vector<int>> basesReg;
1119+
for (int logReg = 0; logReg < llvm::Log2_32(8 * 16 / elemBitWidth);
1120+
logReg++) {
1121+
auto reg = 1 << logReg;
1122+
basesReg.push_back({0, reg});
1123+
}
1124+
std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
1125+
int numTileCols;
1126+
// Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
11251127
// efficiently. opIdx=0 and opIdx=1 are handled differently.
11261128
if (opIdx == 0) {
1127-
// The matrix elements of thread 0 are distributed in the following pattern:
1129+
// The matrix elements of thread 0 are distributed in the following pattern
1130+
// (fp16):
11281131
//
11291132
// col0 col8
11301133
// row0 reg[0-1] reg[4-5]
11311134
// row8 reg[2-3] reg[6-7]
1132-
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
1133-
int row = 1 << logRow;
1134-
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
1135-
}
1136-
basesLane.push_back({0, numColsPerTile / 2});
1137-
// Expand the `register` dimension so the size of columns matches `K`.
1138-
for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
1139-
logCol++) {
1140-
int col = 1 << logCol;
1141-
basesReg.push_back({0, numColsPerTile * col});
1135+
if (needTrans) {
1136+
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
1137+
"supported in the transposed mode");
1138+
basesLane.push_back({0, 8});
1139+
basesLane.push_back({8, 0});
1140+
} else {
1141+
basesLane.push_back({8, 0});
1142+
basesLane.push_back({0, 8 * 16 / elemBitWidth});
11421143
}
1144+
numTileCols = 16 * 16 / elemBitWidth;
11431145
} else {
1144-
// The matrix elements of thread 0 are distributed in the following pattern:
1146+
// The matrix elements of thread 0 are distributed in the following pattern
1147+
// (fp16):
11451148
//
11461149
// col0 col8 col16 col24
11471150
// row0 reg[0-1] reg[2-3] reg[4-5] reg[6-7]
1148-
// 8x8
1149-
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
1150-
int row = 1 << logRow;
1151-
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
1152-
}
1153-
// 8x16
1154-
basesLane.push_back({0, numColsPerTile / 2});
1155-
// 8x32
1156-
basesLane.push_back({0, numColsPerTile});
1157-
// Expand the `register` dimension so the size of columns matches `K`.
1158-
for (int logCol = 0;
1159-
logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
1160-
int col = 1 << logCol;
1161-
basesReg.push_back({0, (numColsPerTile * 2) * col});
1151+
if (needTrans) {
1152+
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
1153+
"supported in the transposed mode");
1154+
basesLane.push_back({8, 0});
1155+
basesLane.push_back({16, 0});
1156+
} else {
1157+
basesLane.push_back({0, 8 * 16 / elemBitWidth});
1158+
basesLane.push_back({0, 16 * 16 / elemBitWidth});
11621159
}
1160+
numTileCols = 32 * 16 / elemBitWidth;
11631161
}
1164-
auto layout = LinearLayout(
1165-
{{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
1162+
// Expand the `register` dimension so the size of columns matches `K`.
1163+
auto layout =
1164+
LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
1165+
{kOuter, kInner}) *
1166+
LinearLayout::identity1D(shape[kDim] / numTileCols, kReg,
1167+
S("dim" + std::to_string(kDim)));
11661168
// Expand the `warp` dimension according to warpsPerCTA.
1169+
auto warpsPerCTA = mma.getWarpsPerCTA();
11671170
layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
11681171
kDim, kWarp)
11691172
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
1170-
auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
1171-
return ret.transposeOuts({kInner, kOuter})
1172-
.reshapeOuts(
1173-
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
1173+
return combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
11741174
}
11751175

11761176
} // anonymous namespace
@@ -1184,13 +1184,10 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
11841184
return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
11851185
}
11861186

1187-
LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
1188-
Attribute dotEnc, ArrayRef<int64_t> shape) {
1189-
auto shared = cast<SharedEncodingAttr>(sharedEnc);
1190-
auto dot = cast<DotOperandEncodingAttr>(dotEnc);
1191-
assert(!shared.getHasLeadingOffset() &&
1192-
"Ldmatrix does not support leading offset yet");
1193-
return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
1187+
LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
1188+
bool needTrans, int32_t elemBitWidth) {
1189+
auto dot = cast<DotOperandEncodingAttr>(enc);
1190+
return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
11941191
}
11951192

11961193
} // namespace mlir::triton::gpu

0 commit comments

Comments
 (0)