@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
817
817
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
818
818
if (!storeScatterOp)
819
819
return failure ();
820
- else if (!storeScatterOp.getOffsets ())
820
+ if (!storeScatterOp.getOffsets ())
821
821
return rewriter.notifyMatchFailure (storeScatterOp,
822
822
" Store op must have offsets argument" );
823
- else if (cast<VectorType>(storeScatterOp.getOffsets ().getType ())
824
- .getRank () != 1 )
823
+ VectorType offsetsTy =
824
+ cast<VectorType>(storeScatterOp.getOffsets ().getType ());
825
+ if (offsetsTy.getRank () != 1 )
825
826
return rewriter.notifyMatchFailure (storeScatterOp,
826
827
" Expected 1D offsets vector" );
827
-
828
828
VectorType storeVecTy =
829
829
cast<VectorType>(storeScatterOp.getValue ().getType ());
830
830
assert (storeVecTy.getRank () <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
836
836
distStoreVecTy = VectorType::Builder (storeVecTy).setDim (0 , 1 );
837
837
838
838
SmallVector<size_t > newRetIndices;
839
- SmallVector<Value> operands =
840
- llvm::to_vector_of<Value>(storeScatterOp->getOperands ());
839
+ SmallVector<Value> operands = storeScatterOp->getOperands ();
841
840
SmallVector<Type> operandTypes =
842
841
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes ());
843
842
operandTypes[0 ] = distStoreVecTy;
843
+ // Assume offset and mask producers will be distributed as well.
844
+ operandTypes[2 ] = VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
845
+ operandTypes[3 ] = VectorType::get (
846
+ {1 }, getElementTypeOrSelf (storeScatterOp.getMask ().getType ()));
844
847
845
848
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
846
849
rewriter, warpOp, operands, operandTypes, newRetIndices);
847
850
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector (
848
851
newRetIndices, [&](size_t idx) { return newWarpOp.getResult (idx); });
849
852
850
- Value offsetsVec = newStoreScatterOpOperands[2 ];
851
- Value maskVec = newStoreScatterOpOperands[3 ];
852
-
853
853
auto loc = newWarpOp.getLoc ();
854
- Value laneId = warpOp.getLaneid ();
855
854
rewriter.setInsertionPointAfter (newWarpOp);
856
- Value laneOffset =
857
- vector::ExtractOp::create (rewriter, loc, offsetsVec, laneId);
858
- laneOffset = vector::BroadcastOp::create (
859
- rewriter, loc, VectorType::get ({1 }, laneOffset.getType ()), laneOffset);
860
- Value laneMask = vector::ExtractOp::create (rewriter, loc, maskVec, laneId);
861
- laneMask = vector::BroadcastOp::create (
862
- rewriter, loc, VectorType::get ({1 }, laneMask.getType ()), laneMask);
863
- newStoreScatterOpOperands[2 ] = laneOffset;
864
- newStoreScatterOpOperands[3 ] = laneMask;
865
-
866
855
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create (
867
856
rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
868
857
storeScatterOp->getAttrs ());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
892
881
if (!loadGatherOp.getOffsets ())
893
882
return rewriter.notifyMatchFailure (loadGatherOp,
894
883
" Load op must have offsets argument" );
895
- else if (cast<VectorType>(loadGatherOp.getOffsets ().getType ()).getRank () !=
896
- 1 )
884
+ VectorType offsetsTy =
885
+ cast<VectorType>(loadGatherOp.getOffsets ().getType ());
886
+ if (offsetsTy.getRank () != 1 )
897
887
return rewriter.notifyMatchFailure (loadGatherOp,
898
888
" Expected 1D offsets vector" );
899
889
900
890
SmallVector<size_t > newRetIndices;
901
- SmallVector<Value> operands =
902
- llvm::to_vector_of<Value>(loadGatherOp->getOperands ());
891
+ SmallVector<Value> operands = loadGatherOp->getOperands ();
903
892
SmallVector<Type> operandTypes =
904
893
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes ());
894
+ // Assume offset and mask producers will be distributed as well.
895
+ operandTypes[1 ] = VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
896
+ operandTypes[2 ] = VectorType::get (
897
+ {1 }, getElementTypeOrSelf (loadGatherOp.getMask ().getType ()));
905
898
906
899
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
907
900
rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
914
907
cast<VectorType>(warpOp.getResult (operandIdx).getType ());
915
908
assert (loadVecTy.getRank () == 1 && " Expected a distributed vector" );
916
909
917
- Value offsetsVec = newLoadGatherOperands[1 ];
918
- Value maskVec = newLoadGatherOperands[2 ];
919
910
auto loc = newWarpOp.getLoc ();
920
- Value laneId = warpOp.getLaneid ();
921
911
rewriter.setInsertionPointAfter (newWarpOp);
922
- Value laneOffset =
923
- vector::ExtractOp::create (rewriter, loc, offsetsVec, laneId);
924
- laneOffset = vector::BroadcastOp::create (
925
- rewriter, loc, VectorType::get ({1 }, laneOffset.getType ()), laneOffset);
926
- Value laneMask = vector::ExtractOp::create (rewriter, loc, maskVec, laneId);
927
- laneMask = vector::BroadcastOp::create (
928
- rewriter, loc, VectorType::get ({1 }, laneMask.getType ()), laneMask);
929
- newLoadGatherOperands[1 ] = laneOffset;
930
- newLoadGatherOperands[2 ] = laneMask;
931
-
932
912
xegpu::LoadGatherOp newOp = rewriter.create <xegpu::LoadGatherOp>(
933
913
loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs ());
934
914
Value distributedVal = newWarpOp.getResult (operandIdx);
0 commit comments