@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
817817 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
818818 if (!storeScatterOp)
819819 return failure ();
820- else if (!storeScatterOp.getOffsets ())
820+ if (!storeScatterOp.getOffsets ())
821821 return rewriter.notifyMatchFailure (storeScatterOp,
822822 " Store op must have offsets argument" );
823- else if (cast<VectorType>(storeScatterOp.getOffsets ().getType ())
824- .getRank () != 1 )
823+ VectorType offsetsTy =
824+ cast<VectorType>(storeScatterOp.getOffsets ().getType ());
825+ if (offsetsTy.getRank () != 1 )
825826 return rewriter.notifyMatchFailure (storeScatterOp,
826827 " Expected 1D offsets vector" );
827-
828828 VectorType storeVecTy =
829829 cast<VectorType>(storeScatterOp.getValue ().getType ());
830830 assert (storeVecTy.getRank () <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
836836 distStoreVecTy = VectorType::Builder (storeVecTy).setDim (0 , 1 );
837837
838838 SmallVector<size_t > newRetIndices;
839- SmallVector<Value> operands =
840- llvm::to_vector_of<Value>(storeScatterOp->getOperands ());
839+ SmallVector<Value> operands = storeScatterOp->getOperands ();
841840 SmallVector<Type> operandTypes =
842841 llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes ());
843842 operandTypes[0 ] = distStoreVecTy;
843+ // Assume offset and mask producers will be distributed as well.
844+ operandTypes[2 ] = VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
845+ operandTypes[3 ] = VectorType::get (
846+ {1 }, getElementTypeOrSelf (storeScatterOp.getMask ().getType ()));
844847
845848 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
846849 rewriter, warpOp, operands, operandTypes, newRetIndices);
847850 SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector (
848851 newRetIndices, [&](size_t idx) { return newWarpOp.getResult (idx); });
849852
850- Value offsetsVec = newStoreScatterOpOperands[2 ];
851- Value maskVec = newStoreScatterOpOperands[3 ];
852-
853853 auto loc = newWarpOp.getLoc ();
854- Value laneId = warpOp.getLaneid ();
855854 rewriter.setInsertionPointAfter (newWarpOp);
856- Value laneOffset =
857- vector::ExtractOp::create (rewriter, loc, offsetsVec, laneId);
858- laneOffset = vector::BroadcastOp::create (
859- rewriter, loc, VectorType::get ({1 }, laneOffset.getType ()), laneOffset);
860- Value laneMask = vector::ExtractOp::create (rewriter, loc, maskVec, laneId);
861- laneMask = vector::BroadcastOp::create (
862- rewriter, loc, VectorType::get ({1 }, laneMask.getType ()), laneMask);
863- newStoreScatterOpOperands[2 ] = laneOffset;
864- newStoreScatterOpOperands[3 ] = laneMask;
865-
866855 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create (
867856 rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
868857 storeScatterOp->getAttrs ());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
892881 if (!loadGatherOp.getOffsets ())
893882 return rewriter.notifyMatchFailure (loadGatherOp,
894883 " Load op must have offsets argument" );
895- else if (cast<VectorType>(loadGatherOp.getOffsets ().getType ()).getRank () !=
896- 1 )
884+ VectorType offsetsTy =
885+ cast<VectorType>(loadGatherOp.getOffsets ().getType ());
886+ if (offsetsTy.getRank () != 1 )
897887 return rewriter.notifyMatchFailure (loadGatherOp,
898888 " Expected 1D offsets vector" );
899889
900890 SmallVector<size_t > newRetIndices;
901- SmallVector<Value> operands =
902- llvm::to_vector_of<Value>(loadGatherOp->getOperands ());
891+ SmallVector<Value> operands = loadGatherOp->getOperands ();
903892 SmallVector<Type> operandTypes =
904893 llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes ());
894+ // Assume offset and mask producers will be distributed as well.
895+ operandTypes[1 ] = VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
896+ operandTypes[2 ] = VectorType::get (
897+ {1 }, getElementTypeOrSelf (loadGatherOp.getMask ().getType ()));
905898
906899 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
907900 rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
914907 cast<VectorType>(warpOp.getResult (operandIdx).getType ());
915908 assert (loadVecTy.getRank () == 1 && " Expected a distributed vector" );
916909
917- Value offsetsVec = newLoadGatherOperands[1 ];
918- Value maskVec = newLoadGatherOperands[2 ];
919910 auto loc = newWarpOp.getLoc ();
920- Value laneId = warpOp.getLaneid ();
921911 rewriter.setInsertionPointAfter (newWarpOp);
922- Value laneOffset =
923- vector::ExtractOp::create (rewriter, loc, offsetsVec, laneId);
924- laneOffset = vector::BroadcastOp::create (
925- rewriter, loc, VectorType::get ({1 }, laneOffset.getType ()), laneOffset);
926- Value laneMask = vector::ExtractOp::create (rewriter, loc, maskVec, laneId);
927- laneMask = vector::BroadcastOp::create (
928- rewriter, loc, VectorType::get ({1 }, laneMask.getType ()), laneMask);
929- newLoadGatherOperands[1 ] = laneOffset;
930- newLoadGatherOperands[2 ] = laneMask;
931-
932912 xegpu::LoadGatherOp newOp = rewriter.create <xegpu::LoadGatherOp>(
933913 loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs ());
934914 Value distributedVal = newWarpOp.getResult (operandIdx);
0 commit comments