@@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to a single-element
+/// vector; **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    if (!storeScatterOp.getOffsets())
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Store op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-    VectorType storeVecTy =
-        cast<VectorType>(storeScatterOp.getValue().getType());
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
     VectorType distStoreVecTy;
@@ -837,80 +858,99 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypes[0] = distStoreVecTy;
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[3] = VectorType::get(
+    operandTypesToYield[0] = distStoreVecTy;
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[3] = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     rewriter.eraseOp(storeScatterOp);
     return success();
   }
 };
 
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
-      if (!isa<xegpu::LoadGatherOp>(op))
-        return false;
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      return yield->getPrevNode() == op;
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check that the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!yieldOperand)
+    if (!producedByLastLoad)
       return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+          warpOp, "The last op is not xegpu::LoadGatherOp");
 
     auto loadGatherOp =
-        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
-    if (!loadGatherOp.getOffsets())
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Load op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(loadGatherOp.getOffsets().getType());
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp, "Load op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[2] = VectorType::get(
-        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[1] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType loadVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
@@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
         continue;
       auto layout =
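
Note: the computation of distStoreVecTy is elided between the hunks above and is not part of this excerpt. Purely as an illustration (not the patch's actual code), the payload-distribution rule stated in the StoreDistribution doc comment could be expressed with the MLIR type API roughly as below; the helper name getDistributedPayloadTypeSketch is hypothetical.

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical sketch of the documented rule:
//   rank-1 SG payload (no chunk size) -> vector<1 x elt>, one element per lane.
//   rank-2 SG payload (chunk size)    -> vector of the innermost dimension.
static VectorType getDistributedPayloadTypeSketch(VectorType sgPayloadTy) {
  assert(sgPayloadTy.getRank() <= 2 && "expected at most 2D SG-level payload");
  if (sgPayloadTy.getRank() == 1)
    return VectorType::get({1}, sgPayloadTy.getElementType());
  // E.g. vector<16x8xf16> with chunk_size=8 -> vector<8xf16>.
  return VectorType::get({sgPayloadTy.getShape().back()},
                         sgPayloadTy.getElementType());
}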