@@ -907,34 +907,48 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-template <class MatrixOp>
-struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+/// Compute the lane-level offsets for a distributed matrix op by adding this
+/// lane's layout-derived offsets to the original subgroup-level offsets.
+/// Returns an empty vector on failure.
+static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newOffsets;
+  auto maybeDescOffsets = layout.computeDistributedOffsets(
+      rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+  if (failed(maybeDescOffsets))
+    return {};
+  assert(maybeDescOffsets.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  newOffsets = llvm::to_vector(llvm::map_range(
+      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+  return newOffsets;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
-    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
     if (!matrixOp)
       return failure();
-    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
-    int operandIdx{-1};
-
-    VectorType sgPayloadTy;
-    VectorType warpResultTy;
-    if constexpr (isLoad) {
-      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
-      });
-      if (!producedByLastLoad)
-        return rewriter.notifyMatchFailure(
-            warpOp, "The last op is not xegpu::LoadMatrixOp");
-      operandIdx = producedByLastLoad->getOperandNumber();
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
-      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
-    } else {
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
-    }
+
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadMatrixOp");
+    const int operandIdx = producedByLastLoad->getOperandNumber();
+
+    VectorType sgPayloadTy =
+        dyn_cast<VectorType>(matrixOp.getResult().getType());
+    VectorType warpResultTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     if (!sgPayloadTy)
       return rewriter.notifyMatchFailure(
           matrixOp, "the matrix op payload must be a vector type");
@@ -956,21 +970,14 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
-          matrixOp,
-          "The matrix op payload has no layouts, using defaults instead.");
-
-    SmallVector<Value> operands;
-    if constexpr (isLoad)
-      operands = {matrixOp.getMemDesc()};
-    else
-      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getMemDesc()};
     const unsigned offsetsStartIdx = operands.size();
     operands.append(offsetsAsValues);
 
     SmallVector<Type> operandTypes = llvm::to_vector(
         llvm::map_range(operands, [](Value v) { return v.getType(); }));
-    if constexpr (!isLoad)
-      operandTypes[0] = *distPayloadByWarpOpOrFailure;
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -986,40 +993,97 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     ValueRange currentOffsets =
         ValueRange(newOperands).drop_front(offsetsStartIdx);
 
-    rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     if (!matrixOp.getSubgroupBlockIoAttr()) {
-      auto maybeDescOffsets = layout.computeDistributedCoords(
-          rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
-          xegpu::DistributionLevel::WI);
-      if (failed(maybeDescOffsets))
-        return failure();
-      assert(maybeDescOffsets.value().size() == 1 &&
-             "Expected same number of offset sets as number of accessed "
-             "sub-tensors or sub-memory descriptors.");
-      SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
-          rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
-          offsets);
-      newOffsets = llvm::to_vector(llvm::map_range(
-          ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+        newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        newWarpOp.getResult(operandIdx),
+        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    return success();
+  }
+};
 
-    if constexpr (isLoad) {
-      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
-          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
-          newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      // Resolve the output type and replace all uses.
-      rewriter.replaceAllUsesWith(
-          newWarpOp.getResult(operandIdx),
-          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
-    } else {
-      xegpu::StoreMatrixOp::create(
-          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
-          ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      rewriter.eraseOp(matrixOp);
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the store op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+        ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(matrixOp);
     return success();
   }
 };
@@ -1551,16 +1615,15 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
-           VectorMultiReductionDistribution, LoadDistribution,
-           StoreDistribution, VectorTransposeDistribution,
-           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
-           MatrixOpDistribution<xegpu::StoreMatrixOp>,
-           MemrefExtractAlignedPointerAsIndexDistribution>(
-          patterns.getContext(),
-          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               GpuBarrierDistribution, VectorMultiReductionDistribution,
+               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+               VectorBitcastDistribution, LoadMatrixDistribution,
+               StoreMatrixDistribution,
+               MemrefExtractAlignedPointerAsIndexDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);