@@ -819,9 +819,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
819819 return isa<vector::TransferReadOp>(op) && op->hasOneUse ();
820820 });
821821 if (!operand)
822- return failure ();
822+ return rewriter.notifyMatchFailure (
823+ warpOp, " warp result is not a vector.transfer_read op" );
823824 auto read = operand->get ().getDefiningOp <vector::TransferReadOp>();
824825
826+ // Source must be defined outside of the region.
827+ if (!warpOp.isDefinedOutsideOfRegion (read.getSource ()))
828+ return rewriter.notifyMatchFailure (
829+ read, " source must be defined outside of the region" );
830+
825831 unsigned operandIndex = operand->getOperandNumber ();
826832 Value distributedVal = warpOp.getResult (operandIndex);
827833
@@ -832,10 +838,25 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
832838 AffineMap map = calculateImplicitMap (sequentialType, distributedType);
833839 AffineMap indexMap = map.compose (read.getPermutationMap ());
834840
835- // Distribute the mask if present.
841+ // Try to delinearize the lane ID to match the rank expected for
842+ // distribution.
843+ SmallVector<Value> delinearizedIds;
844+ if (!delinearizeLaneId (rewriter, read.getLoc (), sequentialType.getShape (),
845+ distributedType.getShape (), warpOp.getWarpSize (),
846+ warpOp.getLaneid (), delinearizedIds)) {
847+ return rewriter.notifyMatchFailure (
848+ read, " cannot delinearize lane ID for distribution" );
849+ }
850+ assert (!delinearizedIds.empty () || map.getNumResults () == 0 );
851+
852+ // Distribute indices and the mask (if present).
836853 OpBuilder::InsertionGuard g (rewriter);
837- WarpExecuteOnLane0Op newWarpOp = warpOp;
838- Value newMask = read.getMask ();
854+ SmallVector<Value> additionalResults (indices.begin (), indices.end ());
855+ SmallVector<Type> additionalResultTypes (indices.size (),
856+ rewriter.getIndexType ());
857+ additionalResults.push_back (read.getPadding ());
858+ additionalResultTypes.push_back (read.getPadding ().getType ());
859+
839860 bool hasMask = false ;
840861 if (read.getMask ()) {
841862 hasMask = true ;
@@ -846,42 +867,26 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
846867 // by shape information on the warp op, and thus requires materializing
847868 // the permutation in IR.
848869 if (!mlir::compressUnusedDims (read.getPermutationMap ()).isIdentity ())
849- return failure ();
870+ return rewriter.notifyMatchFailure (
871+ read, " non-trivial permutation maps not supported" );
850872 VectorType maskType =
851873 getDistributedType (read.getMaskType (), map, warpOp.getWarpSize ());
852- SmallVector<size_t > newRetIndices;
853- newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
854- rewriter, warpOp, ValueRange{read.getMask ()}, TypeRange{maskType},
855- newRetIndices);
856- newMask = newWarpOp.getResult (newRetIndices[0 ]);
857- distributedVal = newWarpOp.getResult (operandIndex);
858- } else {
859- // This pattern does not actually change the warp op directly. Instead it
860- // just rewrites a new transfer read (when not masked) outside of the warp
861- // op and replaces the correponding result. There are then follow up
862- // patterns to erase now dead results of the warp op. This erasure allows
863- // propagation to continue, but this pattern on its own never actually
864- // tells the pattern rewriter that the warp op "changed." Notify the
865- // rewriter here that the warp op is changing. Similar situations are
866- // noted in following patterns.
867- rewriter.startRootUpdate (warpOp);
874+ additionalResults.push_back (read.getMask ());
875+ additionalResultTypes.push_back (maskType);
868876 }
869877
870- rewriter.setInsertionPointAfter (newWarpOp);
878+ SmallVector<size_t > newRetIndices;
879+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
880+ rewriter, warpOp, additionalResults, additionalResultTypes,
881+ newRetIndices);
882+ distributedVal = newWarpOp.getResult (operandIndex);
871883
872- // Try to delinearize the lane ID to match the rank expected for
873- // distribution.
874- SmallVector<Value> delinearizedIds;
875- if (!delinearizeLaneId (rewriter, read.getLoc (), sequentialType.getShape (),
876- distributedType.getShape (), newWarpOp.getWarpSize (),
877- newWarpOp.getLaneid (), delinearizedIds)) {
878- if (!hasMask)
879- rewriter.cancelRootUpdate (warpOp);
880- return rewriter.notifyMatchFailure (
881- read, " cannot delinearize lane ID for distribution" );
882- }
883- assert (!delinearizedIds.empty () || map.getNumResults () == 0 );
884+ // Distributed indices were appended first.
885+ SmallVector<Value> newIndices;
886+ for (int64_t i = 0 , e = indices.size (); i < e; ++i)
887+ newIndices.push_back (newWarpOp.getResult (newRetIndices[i]));
884888
889+ rewriter.setInsertionPointAfter (newWarpOp);
885890 for (auto it : llvm::zip_equal (indexMap.getResults (), map.getResults ())) {
886891 AffineExpr d0, d1;
887892 bindDims (read.getContext (), d0, d1);
@@ -891,42 +896,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
891896 unsigned indexPos = indexExpr.getPosition ();
892897 unsigned vectorPos = cast<AffineDimExpr>(std::get<1 >(it)).getPosition ();
893898 int64_t scale = distributedType.getDimSize (vectorPos);
894- indices [indexPos] = affine::makeComposedAffineApply (
899+ newIndices [indexPos] = affine::makeComposedAffineApply (
895900 rewriter, read.getLoc (), d0 + scale * d1,
896- {indices [indexPos], delinearizedIds[vectorPos]});
901+ {newIndices [indexPos], delinearizedIds[vectorPos]});
897902 }
903+
904+ // Distributed padding value was appended right after the indices.
905+ Value newPadding = newWarpOp.getResult (newRetIndices[indices.size ()]);
906+ // Distributed mask value was added at the end (if the op has a mask).
907+ Value newMask =
908+ hasMask ? newWarpOp.getResult (newRetIndices[newRetIndices.size () - 1 ])
909+ : Value ();
898910 auto newRead = rewriter.create <vector::TransferReadOp>(
899- read.getLoc (), distributedVal.getType (), read.getSource (), indices ,
900- read.getPermutationMapAttr (), read. getPadding () , newMask,
911+ read.getLoc (), distributedVal.getType (), read.getSource (), newIndices ,
912+ read.getPermutationMapAttr (), newPadding , newMask,
901913 read.getInBoundsAttr ());
902914
903- // Check that the produced operation is legal.
904- // The transfer op may be reading from values that are defined within
905- // warpOp's body, which is illegal.
906- // We do the check late because incdices may be changed by
907- // makeComposeAffineApply. This rewrite may remove dependencies from
908- // warpOp's body.
909- // E.g., warpop {
910- // %idx = affine.apply...[%outsideDef]
911- // ... = transfer_read ...[%idx]
912- // }
913- // will be rewritten in:
914- // warpop {
915- // }
916- // %new_idx = affine.apply...[%outsideDef]
917- // ... = transfer_read ...[%new_idx]
918- if (!llvm::all_of (newRead->getOperands (), [&](Value value) {
919- return (newRead.getMask () && value == newRead.getMask ()) ||
920- newWarpOp.isDefinedOutsideOfRegion (value);
921- })) {
922- if (!hasMask)
923- rewriter.cancelRootUpdate (warpOp);
924- return failure ();
925- }
926-
927915 rewriter.replaceAllUsesWith (distributedVal, newRead);
928- if (!hasMask)
929- rewriter.finalizeRootUpdate (warpOp);
930916 return success ();
931917 }
932918};
@@ -1315,6 +1301,12 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13151301 unsigned int operandNumber = operand->getOperandNumber ();
13161302 auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
13171303 VectorType extractSrcType = extractOp.getSourceVectorType ();
1304+ // TODO: Supported shuffle types should be parameterizable, similar to
1305+ // `WarpShuffleFromIdxFn`.
1306+ if (!extractSrcType.getElementType ().isF32 () &&
1307+ !extractSrcType.getElementType ().isInteger (32 ))
1308+ return rewriter.notifyMatchFailure (
1309+ extractOp, " only f32/i32 element types are supported" );
13181310 bool is0dOrVec1Extract = extractSrcType.getNumElements () == 1 ;
13191311 Type elType = extractSrcType.getElementType ();
13201312 VectorType distributedVecType;
0 commit comments