@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
12291229 VectorType extractSrcType = extractOp.getSourceVectorType ();
12301230 Location loc = extractOp.getLoc ();
12311231
1232- // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233- assert (extractSrcType.getRank () > 0 &&
1234- " vector.extract does not support rank 0 sources" );
1235-
1236- // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237- // canonicalized to %v.
1238- if (extractOp.getNumIndices () == 0 )
1232+ // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1233+ if (extractSrcType.getRank () <= 1 ) {
12391234 return failure ();
1240-
1241- // Rewrite vector.extract with 1d source to vector.extractelement.
1242- if (extractSrcType.getRank () == 1 ) {
1243- if (extractOp.hasDynamicPosition ())
1244- // TODO: Dinamic position not supported yet.
1245- return failure ();
1246-
1247- assert (extractOp.getNumIndices () == 1 && " expected 1 index" );
1248- int64_t pos = extractOp.getStaticPosition ()[0 ];
1249- rewriter.setInsertionPoint (extractOp);
1250- rewriter.replaceOpWithNewOp <vector::ExtractElementOp>(
1251- extractOp, extractOp.getVector (),
1252- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1253- return success ();
12541235 }
12551236
12561237 // All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
13131294 }
13141295};
13151296
1316- // / Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317- // / need to be distributed and can just be propagated outside of the region .
1318- struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1319- WarpOpExtractElement (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320- PatternBenefit b = 1 )
1297+ // / Pattern to move out vector.extract with a scalar result.
1298+ // / Only supports 1-D and 0-D sources for now .
1299+ struct WarpOpExtractScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
1300+ WarpOpExtractScalar (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1301+ PatternBenefit b = 1 )
13211302 : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
13221303 warpShuffleFromIdxFn (std::move(fn)) {}
13231304 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
13241305 PatternRewriter &rewriter) const override {
13251306 OpOperand *operand =
1326- getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp >);
1307+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractOp >);
13271308 if (!operand)
13281309 return failure ();
13291310 unsigned int operandNumber = operand->getOperandNumber ();
1330- auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp >();
1311+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractOp >();
13311312 VectorType extractSrcType = extractOp.getSourceVectorType ();
1313+ // Only supports 1-D or 0-D sources for now.
1314+ if (extractSrcType.getRank () > 1 ) {
1315+ return rewriter.notifyMatchFailure (
1316+ extractOp, " only 0-D or 1-D source supported for now" );
1317+ }
13321318 // TODO: Supported shuffle types should be parameterizable, similar to
13331319 // `WarpShuffleFromIdxFn`.
13341320 if (!extractSrcType.getElementType ().isF32 () &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13401326 VectorType distributedVecType;
13411327 if (!is0dOrVec1Extract) {
13421328 assert (extractSrcType.getRank () == 1 &&
1343- " expected that extractelement src rank is 0 or 1" );
1329+ " expected that extract src rank is 0 or 1" );
13441330 if (extractSrcType.getShape ()[0 ] % warpOp.getWarpSize () != 0 )
13451331 return failure ();
13461332 int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13521338 // Yield source vector and position (if present) from warp op.
13531339 SmallVector<Value> additionalResults{extractOp.getVector ()};
13541340 SmallVector<Type> additionalResultTypes{distributedVecType};
1355- if (static_cast <bool >(extractOp.getPosition ())) {
1356- additionalResults.push_back (extractOp.getPosition ());
1357- additionalResultTypes.push_back (extractOp.getPosition ().getType ());
1358- }
1341+ additionalResults.append (
1342+ SmallVector<Value>(extractOp.getDynamicPosition ()));
1343+ additionalResultTypes.append (
1344+ SmallVector<Type>(extractOp.getDynamicPosition ().getTypes ()));
1345+
13591346 Location loc = extractOp.getLoc ();
13601347 SmallVector<size_t > newRetIndices;
13611348 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13691356 if (is0dOrVec1Extract) {
13701357 Value newExtract;
13711358 if (extractSrcType.getRank () == 1 ) {
1372- newExtract = rewriter.create <vector::ExtractElementOp>(
1373- loc, distributedVec,
1374- rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
1375-
1359+ newExtract = rewriter.create <vector::ExtractOp>(loc, distributedVec, 0 );
13761360 } else {
1377- newExtract =
1378- rewriter. create <vector::ExtractElementOp>(loc, distributedVec );
1361+ newExtract = rewriter. create <vector::ExtractOp>(loc, distributedVec,
1362+ ArrayRef< int64_t >{} );
13791363 }
13801364 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
13811365 newExtract);
13821366 return success ();
13831367 }
13841368
1369+ int64_t staticPos = extractOp.getStaticPosition ()[0 ];
1370+ OpFoldResult pos = ShapedType::isDynamic (staticPos)
1371+ ? (newWarpOp->getResult (newRetIndices[1 ]))
1372+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
13851373 // 1d extract: Distribute the source vector. One lane extracts and shuffles
13861374 // the value to all other lanes.
13871375 int64_t elementsPerLane = distributedVecType.getShape ()[0 ];
13881376 AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
13891377 // tid of extracting thread: pos / elementsPerLane
1390- Value broadcastFromTid = rewriter.create <affine::AffineApplyOp>(
1391- loc, sym0.ceilDiv (elementsPerLane),
1392- newWarpOp->getResult (newRetIndices[1 ]));
1378+ Value broadcastFromTid = affine::makeComposedAffineApply (
1379+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos);
13931380 // Extract at position: pos % elementsPerLane
1394- Value pos =
1381+ Value newPos =
13951382 elementsPerLane == 1
13961383 ? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1397- : rewriter
1398- .create <affine::AffineApplyOp>(
1399- loc, sym0 % elementsPerLane,
1400- newWarpOp->getResult (newRetIndices[1 ]))
1401- .getResult ();
1384+ : affine::makeComposedAffineApply (rewriter, loc,
1385+ sym0 % elementsPerLane, pos);
14021386 Value extracted =
1403- rewriter.create <vector::ExtractElementOp >(loc, distributedVec, pos );
1387+ rewriter.create <vector::ExtractOp >(loc, distributedVec, newPos );
14041388
14051389 // Shuffle the extracted value to all lanes.
14061390 Value shuffled = warpShuffleFromIdxFn (
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14131397 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
14141398};
14151399
1416- struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1400+ // / Pattern to convert vector.extractelement to vector.extract. It looks up a warp op operand defined by a vector.extractelement op and replaces that op, in place, with a vector.extract on the same source vector. It does not move anything out of the warp op itself; presumably the vector.extract-based patterns take over from there (confirm against the pattern list).
1401+ struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1402+ WarpOpExtractElement (MLIRContext *ctx, PatternBenefit b = 1 )
1403+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
1404+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1405+ PatternRewriter &rewriter) const override {
1406+ OpOperand *operand =
1407+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1408+ if (!operand)
1409+ return failure (); // getWarpResult found no matching operand; nothing to rewrite.
1410+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
1411+ rewriter.setInsertionPoint (extractOp); // Set the insertion point to the op being replaced.
1412+ if (auto pos = extractOp.getPosition ()) { // Position is present: forward it as the extract index.
1413+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1414+ extractOp, extractOp.getVector (), pos);
1415+ } else { // No position: build the vector.extract with an empty index list.
1416+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1417+ extractOp, extractOp.getVector (), ArrayRef<int64_t >{});
1418+ }
1419+ return success ();
1420+ }
1421+ };
1422+
1423+ // / Pattern to move out vector.insert with a scalar input.
1424+ // / Only supports 1-D and 0-D destinations for now.
1425+ struct WarpOpInsertScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
14171426 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
14181427
14191428 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
14201429 PatternRewriter &rewriter) const override {
1421- OpOperand *operand =
1422- getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1430+ OpOperand *operand = getWarpResult (warpOp, llvm::IsaPred<vector::InsertOp>);
14231431 if (!operand)
14241432 return failure ();
14251433 unsigned int operandNumber = operand->getOperandNumber ();
1426- auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp >();
1434+ auto insertOp = operand->get ().getDefiningOp <vector::InsertOp >();
14271435 VectorType vecType = insertOp.getDestVectorType ();
14281436 VectorType distrType =
14291437 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
1430- bool hasPos = static_cast <bool >(insertOp.getPosition ());
1438+
1439+ // Only supports 1-D or 0-D destinations for now.
1440+ if (vecType.getRank () > 1 ) {
1441+ return rewriter.notifyMatchFailure (
1442+ insertOp, " only 0-D or 1-D source supported for now" );
1443+ }
14311444
14321445 // Yield destination vector, source scalar and position from warp op.
14331446 SmallVector<Value> additionalResults{insertOp.getDest (),
14341447 insertOp.getSource ()};
14351448 SmallVector<Type> additionalResultTypes{distrType,
14361449 insertOp.getSource ().getType ()};
1437- if (hasPos) {
1438- additionalResults. push_back (insertOp. getPosition ());
1439- additionalResultTypes. push_back (insertOp.getPosition ().getType ( ));
1440- }
1450+ additionalResults. append (SmallVector<Value>(insertOp. getDynamicPosition ()));
1451+ additionalResultTypes. append (
1452+ SmallVector<Type> (insertOp.getDynamicPosition ().getTypes () ));
1453+
14411454 Location loc = insertOp.getLoc ();
14421455 SmallVector<size_t > newRetIndices;
14431456 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14461459 rewriter.setInsertionPointAfter (newWarpOp);
14471460 Value distributedVec = newWarpOp->getResult (newRetIndices[0 ]);
14481461 Value newSource = newWarpOp->getResult (newRetIndices[1 ]);
1449- Value newPos = hasPos ? newWarpOp->getResult (newRetIndices[2 ]) : Value ();
14501462 rewriter.setInsertionPointAfter (newWarpOp);
14511463
1464+ OpFoldResult pos;
1465+ if (vecType.getRank () != 0 ) {
1466+ int64_t staticPos = insertOp.getStaticPosition ()[0 ];
1467+ pos = ShapedType::isDynamic (staticPos)
1468+ ? (newWarpOp->getResult (newRetIndices[2 ]))
1469+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
1470+ }
1471+
1472+ // This condition is always true for 0-d vectors.
14521473 if (vecType == distrType) {
1453- // Broadcast: Simply move the vector.inserelement op out.
1454- Value newInsert = rewriter.create <vector::InsertElementOp>(
1455- loc, newSource, distributedVec, newPos);
1474+ Value newInsert;
1475+ if (pos) {
1476+ newInsert = rewriter.create <vector::InsertOp>(loc, newSource,
1477+ distributedVec, pos);
1478+ } else {
1479+ newInsert = rewriter.create <vector::InsertOp>(
1480+ loc, newSource, distributedVec, ArrayRef<int64_t >{});
1481+ }
1482+ // Broadcast: Simply move the vector.insert op out.
14561483 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
14571484 newInsert);
14581485 return success ();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14621489 int64_t elementsPerLane = distrType.getShape ()[0 ];
14631490 AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
14641491 // tid of inserting thread: pos / elementsPerLane
1465- Value insertingLane = rewriter. create < affine::AffineApplyOp> (
1466- loc, sym0.ceilDiv (elementsPerLane), newPos );
1492+ Value insertingLane = affine::makeComposedAffineApply (
1493+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos );
14671494 // Insert position: pos % elementsPerLane
1468- Value pos =
1469- elementsPerLane == 1
1470- ? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1471- : rewriter
1472- .create <affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473- newPos)
1474- .getResult ();
1495+ OpFoldResult newPos = affine::makeComposedFoldedAffineApply (
1496+ rewriter, loc, sym0 % elementsPerLane, pos);
14751497 Value isInsertingLane = rewriter.create <arith::CmpIOp>(
14761498 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid (), insertingLane);
14771499 Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14801502 loc, isInsertingLane,
14811503 /* thenBuilder=*/
14821504 [&](OpBuilder &builder, Location loc) {
1483- Value newInsert = builder.create <vector::InsertElementOp >(
1484- loc, newSource, distributedVec, pos );
1505+ Value newInsert = builder.create <vector::InsertOp >(
1506+ loc, newSource, distributedVec, newPos );
14851507 builder.create <scf::YieldOp>(loc, newInsert);
14861508 },
14871509 /* elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
15061528 auto insertOp = operand->get ().getDefiningOp <vector::InsertOp>();
15071529 Location loc = insertOp.getLoc ();
15081530
1509- // "vector.insert %v, %v[] : ..." can be canonicalized to %v .
1510- if (insertOp.getNumIndices () == 0 )
1531+ // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern .
1532+ if (insertOp.getDestVectorType (). getRank () <= 1 ) {
15111533 return failure ();
1512-
1513- // Rewrite vector.insert with 1d dest to vector.insertelement.
1514- if (insertOp.getDestVectorType ().getRank () == 1 ) {
1515- if (insertOp.hasDynamicPosition ())
1516- // TODO: Dinamic position not supported yet.
1517- return failure ();
1518-
1519- assert (insertOp.getNumIndices () == 1 && " expected 1 index" );
1520- int64_t pos = insertOp.getStaticPosition ()[0 ];
1521- rewriter.setInsertionPoint (insertOp);
1522- rewriter.replaceOpWithNewOp <vector::InsertElementOp>(
1523- insertOp, insertOp.getSource (), insertOp.getDest (),
1524- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1525- return success ();
15261534 }
15271535
1536+ // All following cases are 2d or higher dimensional destination vectors.
1537+
15281538 if (warpOp.getResult (operandNumber).getType () == operand->get ().getType ()) {
15291539 // There is no distribution, this is a broadcast. Simply move the insert
15301540 // out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
16201630 }
16211631};
16221632
1633+ struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> { // Pattern to convert vector.insertelement to vector.insert: the op is replaced in place with a vector.insert using the same source and destination.
1634+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1635+
1636+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1637+ PatternRewriter &rewriter) const override {
1638+ OpOperand *operand =
1639+ getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1640+ if (!operand)
1641+ return failure (); // getWarpResult found no matching operand; nothing to rewrite.
1642+ auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp>();
1643+ rewriter.setInsertionPoint (insertOp); // Set the insertion point to the op being replaced.
1644+ if (auto pos = insertOp.getPosition ()) { // Position is present: forward it as the insert index.
1645+ rewriter.replaceOpWithNewOp <vector::InsertOp>(
1646+ insertOp, insertOp.getSource (), insertOp.getDest (), pos);
1647+ } else { // No position: build the vector.insert with an empty index list.
1648+ rewriter.replaceOpWithNewOp <vector::InsertOp>(
1649+ insertOp, insertOp.getSource (), insertOp.getDest (),
1650+ ArrayRef<int64_t >{});
1651+ }
1652+ return success ();
1653+ }
1654+ };
1655+
16231656// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624- // / the scf.ForOp is the last operation in the region so that it doesn't change
1625- // / the order of execution. This creates a new scf.for region after the
1657+ // / the scf.ForOp is the last operation in the region so that it doesn't
1658+ // / change the order of execution. This creates a new scf.for region after the
16261659// / WarpExecuteOnLane0Op. The new scf.for region will contain a new
16271660// / WarpExecuteOnLane0Op region. Example:
16281661// / ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
16681701 if (!forOp)
16691702 return failure ();
16701703 // Collect Values that come from the warp op but are outside the forOp.
1671- // Those Value needs to be returned by the original warpOp and passed to the
1672- // new op.
1704+ // Those Values need to be returned by the original warpOp and passed to
1705+ // the new op.
16731706 llvm::SmallSetVector<Value, 32 > escapingValues;
16741707 SmallVector<Type> inputTypes;
16751708 SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17151748 OpBuilder::InsertionGuard g (rewriter);
17161749 rewriter.setInsertionPointAfter (newWarpOp);
17171750
1718- // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719- // inside.
1751+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
1752+ // region inside.
17201753 auto newForOp = rewriter.create <scf::ForOp>(
17211754 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
17221755 forOp.getStep (), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17781811};
17791812
17801813// / A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781- // / The vector is reduced in parallel. Currently limited to vector size matching
1782- // / the warpOp size. E.g.:
1814+ // / The vector is reduced in parallel. Currently limited to vector size
1815+ // / matching the warpOp size. E.g.:
17831816// / ```
17841817// / %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
17851818// / %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
18801913 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
18811914 PatternBenefit readBenefit) {
18821915 patterns.add <WarpOpTransferRead>(patterns.getContext (), readBenefit);
1883- patterns
1884- . add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast ,
1885- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant ,
1886- WarpOpInsertElement , WarpOpInsert, WarpOpCreateMask>(
1887- patterns.getContext (), benefit);
1888- patterns.add <WarpOpExtractElement >(patterns.getContext (),
1889- warpShuffleFromIdxFn, benefit);
1916+ patterns. add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1917+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand ,
1918+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement ,
1919+ WarpOpInsertScalar , WarpOpInsert, WarpOpCreateMask>(
1920+ patterns.getContext (), benefit);
1921+ patterns.add <WarpOpExtractScalar >(patterns.getContext (), warpShuffleFromIdxFn ,
1922+ benefit);
18901923 patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
18911924 benefit);
18921925}
0 commit comments