@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };

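+/// Distribute a xegpu::StoreScatterOp (with explicit offsets) that feeds into
+/// the terminator of the enclosing WarpExecuteOnLane0Op: each lane performs
+/// its own store using a lane-local offset and mask.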
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    if (cast<VectorType>(storeScatterOp.getOffsets().getType()).getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
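+    // Compute the per-lane type of the stored value: drop the leading
+    // dimension of a 2-D payload, or shrink a 1-D payload to one element.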
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
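+    // Rebuild the warp op so it additionally yields the store operands; the
+    // stored value (operand 0) is yielded with its distributed type.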
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+    operandTypes[0] = distStoreVecTy;
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
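+    // Extract the lane-local offset and mask and rewrap them as 1-element
+    // vectors for the distributed store.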
+    Value offsetsVec = newStoreScatterOpOperands[2];
+    Value maskVec = newStoreScatterOpOperands[3];
+
+    auto loc = newWarpOp.getLoc();
+    Value laneId = newWarpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newStoreScatterOpOperands[2] = laneOffset;
+    newStoreScatterOpOperands[3] = laneMask;
+
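+    // Recreate the scatter store outside the warp op with the distributed
+    // operands; layout attributes are no longer needed at lane level.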
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
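+/// Distribute a xegpu::LoadGatherOp (with explicit offsets) whose result is
+/// yielded by the enclosing WarpExecuteOnLane0Op: each lane loads a single
+/// element using a lane-local offset and mask.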
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadGatherOp>(op))
+        return false;
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+    if (!yieldOperand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not produced by a xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+    if (!loadGatherOp.getOffsets())
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Load op must have offsets argument");
+    if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets vector");
+
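+    // Rebuild the warp op so it additionally yields the load operands, making
+    // them available outside the region.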
+    // Capture the result index up front, since rebuilding the warp op below
+    // rewrites the yield op's operand list.
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
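+    // The distributed result type is the type already carried by the
+    // corresponding warp op result.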
+    VectorType loadVecTy =
+        cast<VectorType>(newWarpOp.getResult(operandIdx).getType());
+    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
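+    // Extract the lane-local offset and mask and rewrap them as 1-element
+    // vectors for the distributed load.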
+    Value offsetsVec = newLoadGatherOperands[1];
+    Value maskVec = newLoadGatherOperands[2];
+    auto loc = newWarpOp.getLoc();
+    Value laneId = newWarpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newLoadGatherOperands[1] = laneOffset;
+    newLoadGatherOperands[2] = laneMask;
+
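+    // Recreate the gather load outside the warp op and reroute users of the
+    // distributed warp result to it.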
+    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace

 namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final

 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }

 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;

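+      // Vector operands of gather/scatter ops are distributed by
+      // LoadDistribution/StoreDistribution; skip the layout-based path here.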
+      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {