@@ -807,200 +807,6 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  }
};

-/// Distribute a scattered store op. The offsets argument is required.
-/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
-/// The layouts are fixed and implicit: one offset/mask per lane.
-/// The pass changes the offset/mask vector shapes to a
-/// single-element vector, **it is assumed that their producer will also be
-/// distributed**. The payload vector also has a fixed distribution:
-/// no chunk size -> vector of one element.
-/// chunk size -> vector of the innermost dimension of the SG-payload.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
-///     memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-///     memref<256xf16>, vector<1xindex>, vector<1xi1>
-/// Example 2 (chunk size, same mask and offsets):
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-struct StoreDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
-    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
-    if (!storeScatterOp)
-      return failure();
-    auto offsets = storeScatterOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Store op must have a vector of offsets argument");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets and mask vector");
-    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    if (storeVecTy.getRank() > 2)
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Expected at most 2D result at SG level");
-
-    std::string layoutPayloadName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
-
-    xegpu::LayoutAttr layoutPayload =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
-    xegpu::LayoutAttr layoutOffsets =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distStoreVecByWarpOpOrFailure) ||
-        failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          storeScatterOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
-        distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
-        storeScatterOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    rewriter.eraseOp(storeScatterOp);
-    return success();
-  }
-};
-
-/// Distribute a scattered load op. The logic and requirements are the same as
-/// for the scattered store distribution. The warpOp's payload vector is
-/// expected to be distributed by the load's result consumer.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
-///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
-///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
-/// Example 2 (chunk size, same mask and offsets):
-///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To
-///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-struct LoadDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-      // Check that the yield operand was produced by the *last* scattered
-      // load op, to avoid sinking it before barriers (maintain memory order).
-      return isa<xegpu::LoadGatherOp>(op) &&
-             warpOp.getTerminator()->getPrevNode() == op;
-    });
-    if (!producedByLastLoad)
-      return rewriter.notifyMatchFailure(
-          warpOp, "The last op is not xegpu::LoadGatherOp");
-
-    auto loadGatherOp =
-        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
-    auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()) ||
-        !isa<VectorType>(loadGatherOp.getMask().getType()))
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Load op must have vector arguments for offsets and mask");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets and mask vector");
-    // Assume offset and mask producers will be distributed as well.
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
-
-    xegpu::LayoutAttr layoutOffsets =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-
-    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
-        cast<VectorType>(warpOp.getResult(operandIdx).getType());
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
-        loadGatherOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
-    return success();
-  }
-};
-
} // namespace

namespace {
@@ -1013,11 +819,10 @@ struct XeGPUSubgroupDistributePass final

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
-           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
-          patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+      patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
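For reference, the entry point kept by this commit is typically consumed by a pattern-rewrite driver inside a pass. The sketch below shows one plausible wiring; the free function runXeGPUDistribution is a hypothetical harness, not part of this commit, and the driver call is spelled per recent MLIR (applyPatternsGreedily; older trees name it applyPatternsAndFoldGreedily):

    // Hypothetical harness for the patterns registered above (sketch only).
    #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static LogicalResult runXeGPUDistribution(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // Pulls in CreateNdDescDistribution, StoreNdDistribution, ...,
      // GpuBarrierDistribution, as registered by populate above.
      xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
      // Apply the patterns to a fixpoint with the greedy rewrite driver.
      return applyPatternsGreedily(root, std::move(patterns));
    }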