@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
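+/// Distribute xegpu::LoadMatrixOp / xegpu::StoreMatrixOp out of a
+/// gpu.warp_execute_on_lane_0 region. The op's vector payload is distributed
+/// to the lanes according to the op's layout attribute, and the lane id is
+/// folded into the innermost offset so each lane accesses its own slice.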
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
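+    // The matrix op must be the op immediately preceding the warp op's
+    // terminator.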
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
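+    // Matrix ops handled here must carry explicit offsets; materialize the
+    // mixed (static/dynamic) offsets as values so they can be passed through
+    // the warp op region.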
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
+
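+    // Yield everything the new lane-level op will need (the memory
+    // descriptor, the payload for stores, and the offsets) from the warp op
+    // so the values are available outside its region.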
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
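+    // For stores, the yielded payload must already have the distributed
+    // (per-lane) vector type.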
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
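+    // Distribute along the innermost dimension: offset each lane's access by
+    // its lane id in the last offset.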
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
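+    // The offsets are all SSA values now, so mark every constant offset slot
+    // as dynamic.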
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
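+    // Recreate the matrix op outside the warp region, operating on the
+    // per-lane payload and offsets; the layout attribute is dropped since the
+    // new op is already at lane level.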
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
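+      // xegpu.load_matrix/store_matrix carry their layout as an op attribute,
+      // so skip attaching a layout to their vector operands here.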
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {