@@ -31,7 +31,8 @@ namespace {
 SmallVector<unsigned>
 getWarpsPerTile(tt::DotOp dotOp,
                 ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
-                const ArrayRef<int64_t> shape, unsigned numWarps) {
+                const ArrayRef<int64_t> shape, unsigned numWarps, const SmallVector<unsigned> &order) {
+
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
   };
@@ -63,7 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
   uint32_t colRowRatio =
       ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);

-  int rowDim = rank - 2, colDim = rank - 1;
+  int rowDim = order[rank - 2], colDim = order[rank - 1];
   do {
     if (ret[rowDim] * ret[colDim] >= numWarps)
       break;
@@ -138,15 +139,18 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
       order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
     } else {
       assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
-      order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
+      assert(aOp->getNumResults() == 1);
+      auto ret = aOp->getResult(0);
+      order = triton::gpu::getOrder(cast<RankedTensorType>(ret.getType()).getEncoding());
     }
-    // order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
     llvm::errs() << "a load order: " << order[0] << ", " << order[1] << "\n";
-
-    // now find the fast changing dimension from the order
+#if 0
+    const bool aIsTransposed = order.size() == 2 && order[0] == 0 && order[1] == 1;
+    llvm::errs() << "Transposed? " << aIsTransposed << "\n";
+#endif

     SmallVector<unsigned> warpsPerTile =
-        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);

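For illustration, the following minimal standalone sketch (plain C++, independent of the Triton sources; the main() driver and printed output are illustrative assumptions, not part of the patch) shows what the rowDim/colDim selection introduced above computes for the two possible 2D orders, assuming the Triton convention that getOrder lists dimensions from fastest- to slowest-varying:

#include <cstdio>
#include <vector>

int main() {
  // {1, 0}: row-major A operand; {0, 1}: transposed / column-major A operand.
  std::vector<std::vector<unsigned>> orders = {{1, 0}, {0, 1}};
  for (const auto &order : orders) {
    size_t rank = order.size();
    // Previous behavior: rowDim/colDim fixed to the last two logical dims.
    int oldRowDim = rank - 2, oldColDim = rank - 1;
    // Behavior from the patch: pick the dims through the load's order.
    int newRowDim = order[rank - 2], newColDim = order[rank - 1];
    std::printf("order={%u,%u}: old (rowDim=%d, colDim=%d) -> new (rowDim=%d, colDim=%d)\n",
                order[0], order[1], oldRowDim, oldColDim, newRowDim, newColDim);
  }
  return 0;
}

With the default 2D order {1, 0} the selected dimensions come out swapped relative to the previous fixed rank-2/rank-1 choice, while a transposed operand with order {0, 1} reproduces the old assignment.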