@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// it is assumed that their producers will also be distributed. The payload
+/// vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size -> vector of the innermost dimension of the SG payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
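For orientation, Example 1 above can be written out in the `gpu.warp_execute_on_lane_0` form that this pattern actually matches. The sketch below is editorial and not part of the patch: `producer_op`, `%src` and `%laneid` are placeholders assumed to be defined in the surrounding function, and any pre-existing results of the warp op are omitted for brevity.

```mlir
// Before: the scattered store is the op right before the terminator.
gpu.warp_execute_on_lane_0(%laneid)[16] {
  %payload = "producer_op"() : () -> vector<16xf16>
  %mask = "producer_op"() : () -> vector<16xi1>
  %offset = "producer_op"() : () -> vector<16xindex>
  xegpu.store %payload, %src[%offset], %mask
      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
  gpu.yield
}

// After StoreDistribution: the store's operands are yielded out of a new warp
// op with per-lane types and the store is re-created after it. Distributing
// the producer ops is left to other patterns.
%r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
    -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
  %payload = "producer_op"() : () -> vector<16xf16>
  %mask = "producer_op"() : () -> vector<16xi1>
  %offset = "producer_op"() : () -> vector<16xindex>
  gpu.yield %payload, %src, %offset, %mask
      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
xegpu.store %r#0, %r#1[%r#2], %r#3
    : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
```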
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warp op's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check that the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it past barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = loadGatherOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
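A corresponding editorial sketch for the gather load (again not taken from the patch; names are placeholders and the warp op's other results are elided). Here the loaded value leaves the region through a warp-op result, so the pattern re-creates the load after the new warp op and redirects uses of that result to it:

```mlir
// Before: the gather load is the op right before the terminator and its value
// reaches the outside through the warp op's (already distributed) result type.
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
  %mask = "producer_op"() : () -> vector<16xi1>
  %offset = "producer_op"() : () -> vector<16xindex>
  %0 = xegpu.load %src[%offset], %mask
      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
  gpu.yield %0 : vector<16xf16>
}

// After LoadDistribution: the load's operands are yielded with per-lane types,
// the load is re-created after the new warp op, and uses of the old result %r
// are replaced with the new load's result.
%w:3 = gpu.warp_execute_on_lane_0(%laneid)[16]
    -> (memref<256xf16>, vector<1xindex>, vector<1xi1>) {
  %mask = "producer_op"() : () -> vector<16xi1>
  %offset = "producer_op"() : () -> vector<16xindex>
  gpu.yield %src, %offset, %mask
      : memref<256xf16>, vector<16xindex>, vector<16xi1>
}
%0 = xegpu.load %w#0[%w#1], %w#2
    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
```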
+
 } // namespace
 
 namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {