@@ -849,18 +849,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
849
849
return rewriter.notifyMatchFailure (storeScatterOp,
850
850
" Expected 1D offsets and mask vector" );
851
851
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType ());
852
- assert (storeVecTy.getRank () <= 2 &&
853
- " Expected at most 2D result at SG level" );
854
- VectorType distStoreVecTy;
855
- if (storeVecTy.getRank () == 2 )
856
- distStoreVecTy = VectorType::Builder (storeVecTy).dropDim (0 );
857
- else // rank 1
858
- distStoreVecTy = VectorType::Builder (storeVecTy).setDim (0 , 1 );
859
- // Assume offset and mask producers will be distributed as well.
860
- VectorType distOffsetsTy =
861
- VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
862
- VectorType distMaskTy = VectorType::get (
863
- {1 }, getElementTypeOrSelf (storeScatterOp.getMask ().getType ()));
852
+ if (storeVecTy.getRank () > 2 )
853
+ return rewriter.notifyMatchFailure (
854
+ storeScatterOp, " Expected at most 2D result at SG level" );
855
+
864
856
std::string layoutPayloadName =
865
857
xegpu::getLayoutName (storeScatterOp->getOpOperand (0 ));
866
858
std::string layoutOffsetsName =
@@ -884,17 +876,20 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
884
876
if (failed (distStoreVecByWarpOpOrFailure) ||
885
877
failed (distOffsetsByWarpOpOrFailure) ||
886
878
failed (distMaskByWarpOpOrFailure)) {
887
- storeScatterOp.emitWarning (
879
+ return rewriter.notifyMatchFailure (
880
+ storeScatterOp,
888
881
" Some vector operands have no layouts, using defaults instead." );
889
882
}
890
- distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or (distStoreVecTy );
891
- distOffsetsTy = distOffsetsByWarpOpOrFailure. value_or (distOffsetsTy);
892
- distMaskTy = distMaskByWarpOpOrFailure. value_or (distMaskTy );
883
+ VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value ( );
884
+ VectorType expectedPayloadTy = VectorType::get (
885
+ {distPayloadTy. getNumElements ()}, distPayloadTy. getElementType () );
893
886
894
887
SmallVector<size_t > newRetIndices;
895
888
SmallVector<Value> operands = storeScatterOp->getOperands ();
896
889
SmallVector<Type> operandTypesToYield = {
897
- distStoreVecTy, operands[1 ].getType (), distOffsetsTy, distMaskTy};
890
+ expectedPayloadTy, operands[1 ].getType (),
891
+ distOffsetsByWarpOpOrFailure.value (),
892
+ distMaskByWarpOpOrFailure.value ()};
898
893
899
894
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
900
895
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -958,10 +953,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
958
953
return rewriter.notifyMatchFailure (loadGatherOp,
959
954
" Expected 1D offsets and mask vector" );
960
955
// Assume offset and mask producers will be distributed as well.
961
- VectorType distOffsetsTy =
962
- VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
963
- VectorType distMaskTy = VectorType::get ({1 }, getElementTypeOrSelf (maskTy));
964
-
965
956
std::string layoutOffsetsName =
966
957
xegpu::getLayoutName (loadGatherOp->getOpOperand (1 ));
967
958
std::string layoutMaskName =
@@ -978,16 +969,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
978
969
getDistVecTypeBasedOnLaneLayout (layoutMask, maskTy);
979
970
if (failed (distOffsetsByWarpOpOrFailure) ||
980
971
failed (distMaskByWarpOpOrFailure)) {
981
- loadGatherOp.emitWarning (
972
+ return rewriter.notifyMatchFailure (
973
+ loadGatherOp,
982
974
" Some vector operands have no layouts, using defaults instead." );
983
975
}
984
- distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or (distOffsetsTy);
985
- distMaskTy = distMaskByWarpOpOrFailure.value_or (distMaskTy);
986
976
987
977
SmallVector<size_t > newRetIndices;
988
978
SmallVector<Value> operands = loadGatherOp->getOperands ();
989
- SmallVector<Type> operandTypesToYield = {operands[0 ].getType (),
990
- distOffsetsTy, distMaskTy};
979
+ SmallVector<Type> operandTypesToYield = {
980
+ operands[0 ].getType (), distOffsetsByWarpOpOrFailure.value (),
981
+ distMaskByWarpOpOrFailure.value ()};
991
982
992
983
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
993
984
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -998,7 +989,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
998
989
const unsigned operandIdx = producedByLastLoad->getOperandNumber ();
999
990
VectorType loadVecTy =
1000
991
cast<VectorType>(warpOp.getResult (operandIdx).getType ());
1001
- assert (loadVecTy.getRank () == 1 && " Expected a distributed vector" );
1002
992
1003
993
rewriter.setInsertionPointAfter (newWarpOp);
1004
994
xegpu::LoadGatherOp newOp = rewriter.create <xegpu::LoadGatherOp>(
0 commit comments