@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
12291229 VectorType extractSrcType = extractOp.getSourceVectorType ();
12301230 Location loc = extractOp.getLoc ();
12311231
1232- // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233- assert (extractSrcType.getRank () > 0 &&
1234- " vector.extract does not support rank 0 sources" );
1235-
1236- // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237- // canonicalized to %v.
1238- if (extractOp.getNumIndices () == 0 )
1232+ // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1233+ if (extractSrcType.getRank () <= 1 ) {
12391234 return failure ();
1240-
1241- // Rewrite vector.extract with 1d source to vector.extractelement.
1242- if (extractSrcType.getRank () == 1 ) {
1243- if (extractOp.hasDynamicPosition ())
1244- // TODO: Dinamic position not supported yet.
1245- return failure ();
1246-
1247- assert (extractOp.getNumIndices () == 1 && " expected 1 index" );
1248- int64_t pos = extractOp.getStaticPosition ()[0 ];
1249- rewriter.setInsertionPoint (extractOp);
1250- rewriter.replaceOpWithNewOp <vector::ExtractElementOp>(
1251- extractOp, extractOp.getVector (),
1252- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1253- return success ();
12541235 }
12551236
12561237 // All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
13131294 }
13141295};
13151296
1316- // / Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317- // / need to be distributed and can just be propagated outside of the region .
1318- struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1319- WarpOpExtractElement (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320- PatternBenefit b = 1 )
1297+ // / Pattern to move out vector.extract with a scalar result.
1298+ // / Only supports 1-D and 0-D sources for now .
1299+ struct WarpOpExtractScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
1300+ WarpOpExtractScalar (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1301+ PatternBenefit b = 1 )
13211302 : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
13221303 warpShuffleFromIdxFn (std::move(fn)) {}
13231304 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
13241305 PatternRewriter &rewriter) const override {
13251306 OpOperand *operand =
1326- getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp >);
1307+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractOp >);
13271308 if (!operand)
13281309 return failure ();
13291310 unsigned int operandNumber = operand->getOperandNumber ();
1330- auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp >();
1311+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractOp >();
13311312 VectorType extractSrcType = extractOp.getSourceVectorType ();
1313+ // Only supports 1-D or 0-D sources for now.
1314+ if (extractSrcType.getRank () > 1 ) {
1315+ return rewriter.notifyMatchFailure (
1316+ extractOp, " only 0-D or 1-D source supported for now" );
1317+ }
13321318 // TODO: Supported shuffle types should be parameterizable, similar to
13331319 // `WarpShuffleFromIdxFn`.
13341320 if (!extractSrcType.getElementType ().isF32 () &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13401326 VectorType distributedVecType;
13411327 if (!is0dOrVec1Extract) {
13421328 assert (extractSrcType.getRank () == 1 &&
1343- " expected that extractelement src rank is 0 or 1" );
1329+ " expected that extract src rank is 0 or 1" );
13441330 if (extractSrcType.getShape ()[0 ] % warpOp.getWarpSize () != 0 )
13451331 return failure ();
13461332 int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13521338 // Yield source vector and position (if present) from warp op.
13531339 SmallVector<Value> additionalResults{extractOp.getVector ()};
13541340 SmallVector<Type> additionalResultTypes{distributedVecType};
1355- if (static_cast <bool >(extractOp.getPosition ())) {
1356- additionalResults.push_back (extractOp.getPosition ());
1357- additionalResultTypes.push_back (extractOp.getPosition ().getType ());
1358- }
1341+ additionalResults.append (
1342+ SmallVector<Value>(extractOp.getDynamicPosition ()));
1343+ additionalResultTypes.append (
1344+ SmallVector<Type>(extractOp.getDynamicPosition ().getTypes ()));
1345+
13591346 Location loc = extractOp.getLoc ();
13601347 SmallVector<size_t > newRetIndices;
13611348 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1368,39 +1355,33 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13681355 // All lanes extract the scalar.
13691356 if (is0dOrVec1Extract) {
13701357 Value newExtract;
1371- if (extractSrcType.getRank () == 1 ) {
1372- newExtract = rewriter.create <vector::ExtractElementOp>(
1373- loc, distributedVec,
1374- rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
1375-
1376- } else {
1377- newExtract =
1378- rewriter.create <vector::ExtractElementOp>(loc, distributedVec);
1379- }
1358+ SmallVector<int64_t > indices (extractSrcType.getRank (), 0 );
1359+ newExtract =
1360+ rewriter.create <vector::ExtractOp>(loc, distributedVec, indices);
13801361 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
13811362 newExtract);
13821363 return success ();
13831364 }
13841365
1366+ int64_t staticPos = extractOp.getStaticPosition ()[0 ];
1367+ OpFoldResult pos = ShapedType::isDynamic (staticPos)
1368+ ? (newWarpOp->getResult (newRetIndices[1 ]))
1369+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
13851370 // 1d extract: Distribute the source vector. One lane extracts and shuffles
13861371 // the value to all other lanes.
13871372 int64_t elementsPerLane = distributedVecType.getShape ()[0 ];
13881373 AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
13891374 // tid of extracting thread: pos / elementsPerLane
1390- Value broadcastFromTid = rewriter.create <affine::AffineApplyOp>(
1391- loc, sym0.ceilDiv (elementsPerLane),
1392- newWarpOp->getResult (newRetIndices[1 ]));
1375+ Value broadcastFromTid = affine::makeComposedAffineApply (
1376+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos);
13931377 // Extract at position: pos % elementsPerLane
1394- Value pos =
1378+ Value newPos =
13951379 elementsPerLane == 1
13961380 ? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1397- : rewriter
1398- .create <affine::AffineApplyOp>(
1399- loc, sym0 % elementsPerLane,
1400- newWarpOp->getResult (newRetIndices[1 ]))
1401- .getResult ();
1381+ : affine::makeComposedAffineApply (rewriter, loc,
1382+ sym0 % elementsPerLane, pos);
14021383 Value extracted =
1403- rewriter.create <vector::ExtractElementOp >(loc, distributedVec, pos );
1384+ rewriter.create <vector::ExtractOp >(loc, distributedVec, newPos );
14041385
14051386 // Shuffle the extracted value to all lanes.
14061387 Value shuffled = warpShuffleFromIdxFn (
@@ -1413,31 +1394,59 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14131394 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
14141395};
14151396
1416- struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1397+ // / Pattern to convert vector.extractelement to vector.extract.
1398+ struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1399+ WarpOpExtractElement (MLIRContext *ctx, PatternBenefit b = 1 )
1400+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
1401+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1402+ PatternRewriter &rewriter) const override {
1403+ OpOperand *operand =
1404+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1405+ if (!operand)
1406+ return failure ();
1407+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
1408+ SmallVector<OpFoldResult> indices;
1409+ if (auto pos = extractOp.getPosition ()) {
1410+ indices.push_back (pos);
1411+ }
1412+ rewriter.setInsertionPoint (extractOp);
1413+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1414+ extractOp, extractOp.getVector (), indices);
1415+ return success ();
1416+ }
1417+ };
1418+
1419+ // / Pattern to move out vector.insert with a scalar input.
1420+ // / Only supports 1-D and 0-D destinations for now.
1421+ struct WarpOpInsertScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
14171422 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
14181423
14191424 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
14201425 PatternRewriter &rewriter) const override {
1421- OpOperand *operand =
1422- getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1426+ OpOperand *operand = getWarpResult (warpOp, llvm::IsaPred<vector::InsertOp>);
14231427 if (!operand)
14241428 return failure ();
14251429 unsigned int operandNumber = operand->getOperandNumber ();
1426- auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp >();
1430+ auto insertOp = operand->get ().getDefiningOp <vector::InsertOp >();
14271431 VectorType vecType = insertOp.getDestVectorType ();
14281432 VectorType distrType =
14291433 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
1430- bool hasPos = static_cast <bool >(insertOp.getPosition ());
1434+
1435+ // Only supports 1-D or 0-D destinations for now.
1436+ if (vecType.getRank () > 1 ) {
1437+ return rewriter.notifyMatchFailure (
1438+ insertOp, " only 0-D or 1-D source supported for now" );
1439+ }
14311440
14321441 // Yield destination vector, source scalar and position from warp op.
14331442 SmallVector<Value> additionalResults{insertOp.getDest (),
14341443 insertOp.getSource ()};
14351444 SmallVector<Type> additionalResultTypes{distrType,
14361445 insertOp.getSource ().getType ()};
1437- if (hasPos) {
1438- additionalResults. push_back (insertOp. getPosition ());
1439- additionalResultTypes. push_back (insertOp.getPosition ().getType ( ));
1440- }
1446+ additionalResults. append (SmallVector<Value>(insertOp. getDynamicPosition ()));
1447+ additionalResultTypes. append (
1448+ SmallVector<Type> (insertOp.getDynamicPosition ().getTypes () ));
1449+
14411450 Location loc = insertOp.getLoc ();
14421451 SmallVector<size_t > newRetIndices;
14431452 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1446,13 +1455,26 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14461455 rewriter.setInsertionPointAfter (newWarpOp);
14471456 Value distributedVec = newWarpOp->getResult (newRetIndices[0 ]);
14481457 Value newSource = newWarpOp->getResult (newRetIndices[1 ]);
1449- Value newPos = hasPos ? newWarpOp->getResult (newRetIndices[2 ]) : Value ();
14501458 rewriter.setInsertionPointAfter (newWarpOp);
14511459
1460+ OpFoldResult pos;
1461+ if (vecType.getRank () != 0 ) {
1462+ int64_t staticPos = insertOp.getStaticPosition ()[0 ];
1463+ pos = ShapedType::isDynamic (staticPos)
1464+ ? (newWarpOp->getResult (newRetIndices[2 ]))
1465+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
1466+ }
1467+
1468+ // This condition is always true for 0-d vectors.
14521469 if (vecType == distrType) {
1453- // Broadcast: Simply move the vector.inserelement op out.
1454- Value newInsert = rewriter.create <vector::InsertElementOp>(
1455- loc, newSource, distributedVec, newPos);
1470+ Value newInsert;
1471+ SmallVector<OpFoldResult> indices;
1472+ if (pos) {
1473+ indices.push_back (pos);
1474+ }
1475+ newInsert = rewriter.create <vector::InsertOp>(loc, newSource,
1476+ distributedVec, indices);
1477+ // Broadcast: Simply move the vector.insert op out.
14561478 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
14571479 newInsert);
14581480 return success ();
@@ -1462,16 +1484,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14621484 int64_t elementsPerLane = distrType.getShape ()[0 ];
14631485 AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
14641486 // tid of extracting thread: pos / elementsPerLane
1465- Value insertingLane = rewriter. create < affine::AffineApplyOp> (
1466- loc, sym0.ceilDiv (elementsPerLane), newPos );
1487+ Value insertingLane = affine::makeComposedAffineApply (
1488+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos );
14671489 // Insert position: pos % elementsPerLane
1468- Value pos =
1469- elementsPerLane == 1
1470- ? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1471- : rewriter
1472- .create <affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473- newPos)
1474- .getResult ();
1490+ OpFoldResult newPos = affine::makeComposedFoldedAffineApply (
1491+ rewriter, loc, sym0 % elementsPerLane, pos);
14751492 Value isInsertingLane = rewriter.create <arith::CmpIOp>(
14761493 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid (), insertingLane);
14771494 Value newResult =
@@ -1480,8 +1497,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14801497 loc, isInsertingLane,
14811498 /* thenBuilder=*/
14821499 [&](OpBuilder &builder, Location loc) {
1483- Value newInsert = builder.create <vector::InsertElementOp >(
1484- loc, newSource, distributedVec, pos );
1500+ Value newInsert = builder.create <vector::InsertOp >(
1501+ loc, newSource, distributedVec, newPos );
14851502 builder.create <scf::YieldOp>(loc, newInsert);
14861503 },
14871504 /* elseBuilder=*/
@@ -1506,25 +1523,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
15061523 auto insertOp = operand->get ().getDefiningOp <vector::InsertOp>();
15071524 Location loc = insertOp.getLoc ();
15081525
1509- // "vector.insert %v, %v[] : ..." can be canonicalized to %v .
1510- if (insertOp.getNumIndices () == 0 )
1526+ // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern .
1527+ if (insertOp.getDestVectorType (). getRank () <= 1 ) {
15111528 return failure ();
1512-
1513- // Rewrite vector.insert with 1d dest to vector.insertelement.
1514- if (insertOp.getDestVectorType ().getRank () == 1 ) {
1515- if (insertOp.hasDynamicPosition ())
1516- // TODO: Dinamic position not supported yet.
1517- return failure ();
1518-
1519- assert (insertOp.getNumIndices () == 1 && " expected 1 index" );
1520- int64_t pos = insertOp.getStaticPosition ()[0 ];
1521- rewriter.setInsertionPoint (insertOp);
1522- rewriter.replaceOpWithNewOp <vector::InsertElementOp>(
1523- insertOp, insertOp.getSource (), insertOp.getDest (),
1524- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1525- return success ();
15261529 }
15271530
1531+ // All following cases are 2d or higher dimensional source vectors.
1532+
15281533 if (warpOp.getResult (operandNumber).getType () == operand->get ().getType ()) {
15291534 // There is no distribution, this is a broadcast. Simply move the insert
15301535 // out of the warp op.
@@ -1620,9 +1625,30 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
16201625 }
16211626};
16221627
1628+ struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1629+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1630+
1631+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1632+ PatternRewriter &rewriter) const override {
1633+ OpOperand *operand =
1634+ getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1635+ if (!operand)
1636+ return failure ();
1637+ auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp>();
1638+ SmallVector<OpFoldResult> indices;
1639+ if (auto pos = insertOp.getPosition ()) {
1640+ indices.push_back (pos);
1641+ }
1642+ rewriter.setInsertionPoint (insertOp);
1643+ rewriter.replaceOpWithNewOp <vector::InsertOp>(
1644+ insertOp, insertOp.getSource (), insertOp.getDest (), indices);
1645+ return success ();
1646+ }
1647+ };
1648+
16231649// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624- // / the scf.ForOp is the last operation in the region so that it doesn't change
1625- // / the order of execution. This creates a new scf.for region after the
1650+ // / the scf.ForOp is the last operation in the region so that it doesn't
1651+ // / change the order of execution. This creates a new scf.for region after the
16261652// / WarpExecuteOnLane0Op. The new scf.for region will contain a new
16271653// / WarpExecuteOnLane0Op region. Example:
16281654// / ```
@@ -1668,8 +1694,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
16681694 if (!forOp)
16691695 return failure ();
16701696 // Collect Values that come from the warp op but are outside the forOp.
1671- // Those Value needs to be returned by the original warpOp and passed to the
1672- // new op.
1697+ // Those Value needs to be returned by the original warpOp and passed to
1698+ // the new op.
16731699 llvm::SmallSetVector<Value, 32 > escapingValues;
16741700 SmallVector<Type> inputTypes;
16751701 SmallVector<Type> distTypes;
@@ -1715,8 +1741,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17151741 OpBuilder::InsertionGuard g (rewriter);
17161742 rewriter.setInsertionPointAfter (newWarpOp);
17171743
1718- // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719- // inside.
1744+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
1745+ // region inside.
17201746 auto newForOp = rewriter.create <scf::ForOp>(
17211747 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
17221748 forOp.getStep (), newOperands);
@@ -1778,8 +1804,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17781804};
17791805
17801806// / A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781- // / The vector is reduced in parallel. Currently limited to vector size matching
1782- // / the warpOp size. E.g.:
1807+ // / The vector is reduced in parallel. Currently limited to vector size
1808+ // / matching the warpOp size. E.g.:
17831809// / ```
17841810// / %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
17851811// / %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1906,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
18801906 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
18811907 PatternBenefit readBenefit) {
18821908 patterns.add <WarpOpTransferRead>(patterns.getContext (), readBenefit);
1883- patterns
1884- . add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast ,
1885- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant ,
1886- WarpOpInsertElement , WarpOpInsert, WarpOpCreateMask>(
1887- patterns.getContext (), benefit);
1888- patterns.add <WarpOpExtractElement >(patterns.getContext (),
1889- warpShuffleFromIdxFn, benefit);
1909+ patterns. add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1910+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand ,
1911+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement ,
1912+ WarpOpInsertScalar , WarpOpInsert, WarpOpCreateMask>(
1913+ patterns.getContext (), benefit);
1914+ patterns.add <WarpOpExtractScalar >(patterns.getContext (), warpShuffleFromIdxFn ,
1915+ benefit);
18901916 patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
18911917 benefit);
18921918}
0 commit comments