@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
                              ArrayRef<int64_t> dpasTile) {
   if (!(isa<linalg::MatmulOp>(linalgOp) ||
         isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+        isa<linalg::MatmulTransposeBOp>(linalgOp) ||
        isa<linalg::GenericOp>(linalgOp))) {
     return false;
   }
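For orientation (not part of the patch), here is a minimal runnable sketch of the semantics the new case admits: `linalg.matmul_transpose_b` computes C += A * B^T, i.e. B is laid out N x K and indexed as B[n][k] instead of B[k][n].

```cpp
#include <cstdio>

// Reference semantics of matmul_transpose_b: C += A * B^T.
int main() {
  const int M = 2, N = 2, K = 3;
  float A[M][K] = {{1, 2, 3}, {4, 5, 6}};
  float B[N][K] = {{1, 0, 1}, {0, 1, 0}}; // B is N x K, not K x N
  float C[M][N] = {};
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        C[m][n] += A[m][k] * B[n][k]; // B indexed [n][k]
  std::printf("%g %g\n%g %g\n", C[0][0], C[0][1], C[1][0], C[1][1]);
}
```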
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
 //
 // The descriptor sub-tiles are ordered in row-major fashion with respect to the
 // whole load tile.
-static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
-                                                Location loc, Value src,
-                                                ArrayRef<int64_t> loadShape,
-                                                ArrayRef<int64_t> loadOffsets,
-                                                ArrayRef<int64_t> descTile,
-                                                int arrayLength = 1) {
+static SmallVector<Value>
+createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
+                      ArrayRef<int64_t> loadShape,
+                      ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
+                      int arrayLength = 1, bool transpose = false) {
   assert(arrayLength == 1 && "Array descriptors are not supported");
 
   auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      if (transpose) {
+        std::swap(newRowOffs, newColOffs);
+      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
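A standalone illustration of the offset walk above (the tile sizes are made up, not hardware values): sub-tile offsets are generated row-major over the load shape, and with `transpose` each (row, col) pair is swapped so the same grid addresses the source as if it were transposed.

```cpp
#include <cstdio>
#include <utility>

int main() {
  const int loadShape[2] = {4, 8}; // illustrative shapes only
  const int descTile[2] = {2, 4};
  const bool transpose = true;
  // Row-major walk over the load tile, mirroring createDescriptorTiles.
  for (int i = 0; i < loadShape[0]; i += descTile[0]) {
    for (int j = 0; j < loadShape[1]; j += descTile[1]) {
      int row = i, col = j;
      if (transpose)
        std::swap(row, col); // same swap the patch applies to the offsets
      std::printf("descriptor offset: (%d, %d)\n", row, col);
    }
  }
}
```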
@@ -693,7 +696,8 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
 static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
                                                Location loc, Value src,
                                                ArrayRef<int64_t> sgTile,
-                                               bool isVnni) {
+                                               bool isVnni,
+                                               bool transpose = false) {
   assert(sgTile.size() <= 2 &&
          "Require at most 2D tile size for eltwise lowering");
 
@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
   // NOLINTEND
 
   return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
-                               {sgLoadRows, sgLoadCols}, arrayLength);
+                               {sgLoadRows, sgLoadCols}, arrayLength,
+                               transpose);
 }
 
 // Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
 loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                 xegpu::CachePolicyAttr hint,
                 std::optional<VnniConfig> vnniConf = std::nullopt,
-                DenseI64ArrayAttr transpose = nullptr) {
+                DenseI64ArrayAttr transpose = nullptr,
+                IntegerAttr transpose_bit = nullptr) {
   // Assume all tiles have the same shape.
   auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                                  *vnniConf);
     packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }
-  IntegerAttr transpose_bit = nullptr;
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
 
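The `transpose_bit` parameter (set to 32 later in this patch) plausibly selects the element granularity of a transposed load; one common reason is that 16-bit types such as bf16 are transposed as packed 32-bit units rather than individual elements. A runnable toy illustrating that granularity, assuming this interpretation:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // 2x4 matrix of 16-bit values, viewed as a 2x2 matrix of 32-bit units.
  uint16_t src[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  uint32_t units[2][2];
  std::memcpy(units, src, sizeof(src));
  // Transpose the 32-bit units, not the individual 16-bit elements,
  // so adjacent 16-bit pairs stay together (as VNNI-style consumers expect).
  uint32_t out[2][2];
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      out[j][i] = units[i][j];
  for (auto &row : out)
    std::printf("%08x %08x\n", row[0], row[1]);
}
```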
@@ -860,13 +865,82 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
   return subTiles;
 }
 
+// Checks whether the given `matmulOperand` is produced by a
+// `linalg::TransposeOp` and ensures that the transpose result is only used by
+// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
+// or `linalg::GenericOp`.
+//
+// If a valid transpose operation is found, the function records it for later
+// removal and returns the operand of the transpose operation as the new matrix
+// multiplication operand.
+static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
+                                                size_t operandIdx,
+                                                PatternRewriter &rewriter) {
+  auto defOp = matmulOperand.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  linalg::TransposeOp transposeOp = nullptr;
+
+  for (auto x : defOp->getUsers()) {
+    if (isa<linalg::TransposeOp>(x)) {
+      if (transposeOp) {
+        return rewriter.notifyMatchFailure(
+            transposeOp, "Only one transpose operation is allowed");
+      }
+
+      transposeOp = dyn_cast<linalg::TransposeOp>(x);
+
+      auto transposeRes = transposeOp.getDpsInits()[0];
+      // Verify that the transpose result has no users other than our matmul.
+      for (auto trUser : transposeRes.getUsers()) {
+        if (isa<linalg::MatmulOp>(trUser) ||
+            isa<linalg::BatchReduceMatmulOp>(trUser) ||
+            isa<linalg::GenericOp>(trUser)) {
+          auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
+          auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
+          if (actualMatmulOperand != matmulOperand) {
+            return rewriter.notifyMatchFailure(
+                trUser,
+                "Transpose result is used by more than one matmul operation");
+          }
+        } else if (isa<memref::DeallocOp>(trUser)) {
+          // Allow deallocs as users.
+          continue;
+        } else if (isa<linalg::TransposeOp>(trUser)) {
+          // Check if it's the same transpose as the one we're processing.
+          if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
+                                                          /*flags=*/nullptr)) {
+            return rewriter.notifyMatchFailure(
+                trUser, "Only one transpose operation is allowed");
+          }
+          continue;
+        } else {
+          return rewriter.notifyMatchFailure(
+              trUser,
+              "Transpose result is not allowed to be used by this operation");
+        }
+      }
+    }
+  }
+  if (transposeOp) {
+    auto ret = transposeOp.getDpsInputs()[0];
+    rewriter.eraseOp(transposeOp);
+    return ret;
+  }
+  return rewriter.notifyMatchFailure(
+      defOp, "No transpose operation producing the operand was found");
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
                                       int prefetchStages,
                                       PatternRewriter &rewriter) {
   assert((isa<linalg::MatmulOp>(linalgOp) ||
           isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+          isa<linalg::MatmulTransposeBOp>(linalgOp) ||
           isa<linalg::GenericOp>(linalgOp)) &&
          "Requires a GEMM-like op for DPAS lowering");
 
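A hedged usage sketch of the helper; it mirrors the call site added in the next hunk, with `matB` and `rewriter` assumed to come from the surrounding rewrite pattern:

```cpp
bool transposeB = false;
FailureOr<Value> newMatB =
    findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
if (succeeded(newMatB)) {
  // Consume the transpose input directly; the transposed access is folded
  // into the XeGPU loads instead of a separate linalg.transpose op.
  matB = *newMatB;
  transposeB = true;
}
```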
@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   auto matB = linalgOp.getDpsInputs()[1];
   auto matC = linalgOp.getDpsInits()[0];
 
+  bool transposeB = false;
+  if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+    transposeB = true;
+  } else {
+    auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
+    if (!failed(newMatB)) {
+      matB = *newMatB;
+      transposeB = true;
+    }
+  }
+
   auto typeA = cast<ShapedType>(matA.getType());
   auto typeC = cast<ShapedType>(matC.getType());
 
@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
 
   // Create B sub-tiles.
   SmallVector<Value> tilesB =
-      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
+      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
+                           /*isVnni=*/true, transposeB);
 
   // Create input prefetch tiles.
   int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                            {dimM, dimN}, kTile);
     auto prefetchDescB = createGemmCoopPrefetchTile(
         rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
-        {dimM, dimN}, kTile);
+        (transposeB) ? std::vector<int32_t>{dimM, dimN}
+                     : std::vector<int32_t>{dimN, dimM},
+        kTile);
 
     if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
       prefetchA = prefetchDescA->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
                                      {0, kTile})[0];
       prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
-                                     {kTile, 0})[0];
+                                     (transposeB)
+                                         ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0})[0];
     }
   } else {
     // Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
   auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());
 
+  DenseI64ArrayAttr transpose = nullptr;
+  IntegerAttr transpose_bit = nullptr;
+
+  if (transposeB) {
+    transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
+    transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
+  }
+
   // Load B sub-tiles.
   SmallVector<Value> loadVecB =
-      loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
+      loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
+                      transpose, transpose_bit);
   auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());
 
   // Update offsets of the input tiles.
   // Shift along the reduction dimension.
   tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
-  tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
+  tilesB = updateTilesOffsets(rewriter, loc, tilesB,
+                              transposeB ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0});
 
   // Prefetch the next set of input tiles.
   if (isCoopPrefetch) {
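The `{0, kTile}` versus `{kTile, 0}` choice encodes where the reduction dimension lives: plain B is K x N, so tiles step down the rows, while transposed B is N x K, so tiles step across the columns. A runnable sketch of the two offset schedules:

```cpp
#include <cstdio>

int main() {
  const int kTile = 16, numSteps = 3;
  for (int transposed = 0; transposed <= 1; ++transposed) {
    int off[2] = {0, 0};
    std::printf("%s B:", transposed ? "transposed" : "plain");
    for (int s = 0; s < numSteps; ++s) {
      // Plain B advances the row offset; transposed B advances the column.
      off[transposed ? 1 : 0] += kTile;
      std::printf(" (%d,%d)", off[0], off[1]);
    }
    std::printf("\n");
  }
}
```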
@@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
     prefetchA =
         updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0];
     prefetchB =
-        updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
+        updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
+                           transposeB ? std::vector<int64_t>{0, kTile}
+                                      : std::vector<int64_t>{kTile, 0})[0];
   } else {
     // Apply naive prefetching for each subgroup separately.
     prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
   // Constrain conversion to the supported GEMM-like ops.
   static_assert(
       llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
-                      linalg::GenericOp>::value);
+                      linalg::GenericOp, linalg::MatmulTransposeBOp>::value);
 
   ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
       : OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
 void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
                                        LinalgToXeGPUOptions options) {
   patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
-               ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
-                                                          options);
+               ConvertGemmLikeToXeGPU<linalg::GenericOp>,
+               ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
+      patterns.getContext(), options);
 }
 
 void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
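For completeness, a hedged sketch of how such a pattern set is typically driven; the pass plumbing sits outside this diff, and `applyPatternsAndFoldGreedily` is MLIR's standard greedy driver (exact includes vary by MLIR version):

```cpp
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collect the GEMM patterns and apply them greedily.
void lowerGemmsToXeGPU(mlir::Operation *root, LinalgToXeGPUOptions options) {
  mlir::RewritePatternSet patterns(root->getContext());
  populateLinalgGemmToXeGPUPatterns(patterns, options);
  // In a real pass, a failure here would call signalPassFailure().
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```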