@@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both the offset and mask vectors must be 1D and have #subgroup_size
+/// elements. The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// it is assumed that their producers will also be distributed. The payload
+/// vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size    -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    if (!storeScatterOp.getOffsets())
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Store op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-    VectorType storeVecTy =
-        cast<VectorType>(storeScatterOp.getValue().getType());
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
     VectorType distStoreVecTy;
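The hunk above ends just before the lines that compute distStoreVecTy, which this diff does not show. A minimal sketch of that shape rule, following the doc comment above (assumed logic, not the verbatim implementation):

```cpp
// Sketch only (assumed): derive the per-lane store vector type from the
// subgroup-level payload type storeVecTy matched earlier in this pattern.
//  - rank-1 payload (no chunk size): each lane stores a single element.
//  - rank-2 payload (chunk size):    each lane stores the innermost dimension.
VectorType distStoreVecTy =
    storeVecTy.getRank() == 1
        ? VectorType::get({1}, storeVecTy.getElementType())
        : VectorType::get({storeVecTy.getShape().back()},
                          storeVecTy.getElementType());
```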
@@ -837,80 +858,99 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypes[0] = distStoreVecTy;
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[3] = VectorType::get(
+    operandTypesToYield[0] = distStoreVecTy;
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[3] = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     rewriter.eraseOp(storeScatterOp);
     return success();
   }
 };
 
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The load's payload (result) vector is
+/// expected to be distributed by the consumer of the warp op result.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
-      if (!isa<xegpu::LoadGatherOp>(op))
-        return false;
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      return yield->getPrevNode() == op;
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Only match a yield operand produced by the *last* scattered load op,
+      // to avoid sinking it past barriers (preserve memory ordering).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!yieldOperand)
+    if (!producedByLastLoad)
       return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+          warpOp, "The last op is not xegpu::LoadGatherOp");
 
     auto loadGatherOp =
-        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
-    if (!loadGatherOp.getOffsets())
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Load op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(loadGatherOp.getOffsets().getType());
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp, "Load op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[2] = VectorType::get(
-        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[1] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType loadVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
@@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
         continue;
       auto layout =
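This diff does not show how the two new patterns are wired into the pass. A minimal sketch of one plausible registration, assuming the usual RewritePatternSet idiom (the helper name below is hypothetical and not part of this change):

```cpp
// Sketch only (assumed): register the scatter/gather distribution patterns
// next to the other subgroup-distribution patterns used by this pass.
static void
registerScatterGatherDistributionPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadDistribution, StoreDistribution>(patterns.getContext());
}
```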