@@ -844,9 +844,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           storeScatterOp, "Store op must have a vector of offsets argument");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
@@ -855,17 +856,45 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
     else // rank 1
       distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypesToYield[0] = distStoreVecTy;
     // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[2] =
+    VectorType distOffsetsTy =
         VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[3] = VectorType::get(
+    VectorType distMaskTy = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      storeScatterOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};

     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
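
Editor's note: StoreDistribution (above) and LoadDistribution (below) now repeat the same lookup-and-fallback sequence per vector operand: fetch the layout attribute by name, try the lane-layout-based distribution, and fall back to the precomputed one-element-per-lane default. A minimal sketch of how that shared step could be factored out, assuming the file's existing getDistVecTypeBasedOnLaneLayout helper and the xegpu::getLayoutName utility; the helper name resolveDistTypeOrDefault is hypothetical, not part of this patch:

    // Hypothetical helper (illustration only): compute the distributed type
    // for one vector operand from its layout attribute, falling back to the
    // caller-provided default when the attribute is absent or unusable.
    static VectorType resolveDistTypeOrDefault(Operation *op,
                                               OpOperand &operand,
                                               VectorType sgTy,
                                               VectorType defaultTy) {
      std::string name = xegpu::getLayoutName(operand);
      auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name);
      FailureOr<VectorType> distTy =
          getDistVecTypeBasedOnLaneLayout(layout, sgTy);
      // Same value_or-style fallback as the two patterns above and below.
      return failed(distTy) ? defaultTy : distTy.value();
    }
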
@@ -918,23 +947,47 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
     auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
       return rewriter.notifyMatchFailure(
-          loadGatherOp, "Load op must have a vector of offsets argument");
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    VectorType distOffsetsTy =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
+
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      loadGatherOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);

     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[1] =
-        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[2] =
-        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
+    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
+                                             distOffsetsTy, distMaskTy};

     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
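
For intuition: under the default SIMT distribution, a subgroup-level vector<N x T> offsets or mask operand is split one element per lane, which is where the VectorType::get({1}, ...) defaults above come from. A rough sketch of what a lane-layout-driven computation does, assuming a LayoutAttr whose getLaneLayout() accessor yields a dense i32 array that evenly divides the vector shape (the real getDistVecTypeBasedOnLaneLayout in this file is the authority; this sketch is not it):

    // Illustration only: per-lane shape = subgroup shape / lane_layout,
    // e.g. vector<16xindex> with lane_layout [16] -> vector<1xindex>.
    static FailureOr<VectorType> sketchDistVecType(xegpu::LayoutAttr layout,
                                                   VectorType sgTy) {
      if (!layout || !layout.getLaneLayout())
        return failure();
      SmallVector<int64_t> shape = llvm::to_vector(sgTy.getShape());
      ArrayRef<int32_t> lanes = layout.getLaneLayout().asArrayRef();
      if (lanes.size() != shape.size())
        return failure();
      for (size_t i = 0; i < shape.size(); ++i)
        shape[i] /= lanes[i]; // assumes even divisibility
      return VectorType::get(shape, sgTy.getElementType());
    }
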
@@ -951,6 +1004,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
         loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
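
Once the load is rebuilt at lane granularity, the subgroup layout attributes no longer describe anything, hence the new xegpu::removeLayoutAttrs(newOp) call. A hedged sketch of the step just before this hunk, assuming newLoadGatherOperands is rebuilt from the new warp op's results via newRetIndices (the names follow the hunks above; the exact surrounding code is not shown in this diff):

    // Sketch (assumed context): gather the distributed operands yielded by
    // the new warp op, then feed them to the per-lane load created above.
    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) -> Value { return newWarpOp.getResult(i); });
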
@@ -990,7 +1044,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {

       // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
-        continue ;
+        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {