Skip to content

Commit db3649c

Browse files
committed
use A matrix layout order when determining dpas order in accelerate matmul
1 parent f2de7cc commit db3649c

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace {
3131
SmallVector<unsigned>
3232
getWarpsPerTile(tt::DotOp dotOp,
3333
ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
34-
const ArrayRef<int64_t> shape, unsigned numWarps) {
34+
const ArrayRef<int64_t> shape, unsigned numWarps, const SmallVector<unsigned>& order) {
35+
3536
auto filter = [&dotOp](Operation *op) {
3637
return op->getParentRegion() == dotOp->getParentRegion();
3738
};
@@ -63,7 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6364
uint32_t colRowRatio =
6465
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
6566

66-
int rowDim = rank - 2, colDim = rank - 1;
67+
int rowDim = order[rank - 2], colDim = order[rank - 1];
6768
do {
6869
if (ret[rowDim] * ret[colDim] >= numWarps)
6970
break;
@@ -138,15 +139,18 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
138139
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
139140
} else {
140141
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
141-
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
142+
assert(aOp->getNumResults() == 1);
143+
auto ret = aOp->getResult(0);
144+
order = triton::gpu::getOrder(cast<RankedTensorType>(ret.getType()).getEncoding());
142145
}
143-
// order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
144146
llvm::errs() << "a load order: " << order[0] << ", " << order[1] << "\n";
145-
146-
// now find the fast changing dimension from the order
147+
#if 0
148+
const bool aIsTransposed = order.size() == 2 && order[0] == 0 && order[1] == 1;
149+
llvm::errs() << "Transposed? " << aIsTransposed << "\n";
150+
#endif
147151

148152
SmallVector<unsigned> warpsPerTile =
149-
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
153+
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
150154
size_t rank = retShape.size();
151155
SmallVector<unsigned> repCluster(rank, 1);
152156

0 commit comments

Comments
 (0)