Skip to content

Commit 6f867b1

Browse files
committed
[mlir][Vector] Move insert/extractelement distribution patterns to insert/extract
1 parent 402efa7 commit 6f867b1

File tree

2 files changed

+149
-118
lines changed

2 files changed

+149
-118
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 137 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
12291229
VectorType extractSrcType = extractOp.getSourceVectorType();
12301230
Location loc = extractOp.getLoc();
12311231

1232-
// "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233-
assert(extractSrcType.getRank() > 0 &&
1234-
"vector.extract does not support rank 0 sources");
1235-
1236-
// "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237-
// canonicalized to %v.
1238-
if (extractOp.getNumIndices() == 0)
1232+
// For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1233+
if (extractSrcType.getRank() <= 1) {
12391234
return failure();
1240-
1241-
// Rewrite vector.extract with 1d source to vector.extractelement.
1242-
if (extractSrcType.getRank() == 1) {
1243-
if (extractOp.hasDynamicPosition())
1244-
// TODO: Dinamic position not supported yet.
1245-
return failure();
1246-
1247-
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1248-
int64_t pos = extractOp.getStaticPosition()[0];
1249-
rewriter.setInsertionPoint(extractOp);
1250-
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
1251-
extractOp, extractOp.getVector(),
1252-
rewriter.create<arith::ConstantIndexOp>(loc, pos));
1253-
return success();
12541235
}
12551236

12561237
// All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
13131294
}
13141295
};
13151296

1316-
/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317-
/// need to be distributed and can just be propagated outside of the region.
1318-
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1319-
WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320-
PatternBenefit b = 1)
1297+
/// Pattern to move out vector.extract with a scalar result.
1298+
/// Only supports 1-D and 0-D sources for now.
1299+
struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
1300+
WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1301+
PatternBenefit b = 1)
13211302
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
13221303
warpShuffleFromIdxFn(std::move(fn)) {}
13231304
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
13241305
PatternRewriter &rewriter) const override {
13251306
OpOperand *operand =
1326-
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1307+
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
13271308
if (!operand)
13281309
return failure();
13291310
unsigned int operandNumber = operand->getOperandNumber();
1330-
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1311+
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
13311312
VectorType extractSrcType = extractOp.getSourceVectorType();
1313+
// Only supports 1-D or 0-D sources for now.
1314+
if (extractSrcType.getRank() > 1) {
1315+
return rewriter.notifyMatchFailure(
1316+
extractOp, "only 0-D or 1-D source supported for now");
1317+
}
13321318
// TODO: Supported shuffle types should be parameterizable, similar to
13331319
// `WarpShuffleFromIdxFn`.
13341320
if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13401326
VectorType distributedVecType;
13411327
if (!is0dOrVec1Extract) {
13421328
assert(extractSrcType.getRank() == 1 &&
1343-
"expected that extractelement src rank is 0 or 1");
1329+
"expected that extract src rank is 0 or 1");
13441330
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
13451331
return failure();
13461332
int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13521338
// Yield source vector and position (if present) from warp op.
13531339
SmallVector<Value> additionalResults{extractOp.getVector()};
13541340
SmallVector<Type> additionalResultTypes{distributedVecType};
1355-
if (static_cast<bool>(extractOp.getPosition())) {
1356-
additionalResults.push_back(extractOp.getPosition());
1357-
additionalResultTypes.push_back(extractOp.getPosition().getType());
1358-
}
1341+
additionalResults.append(
1342+
SmallVector<Value>(extractOp.getDynamicPosition()));
1343+
additionalResultTypes.append(
1344+
SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1345+
13591346
Location loc = extractOp.getLoc();
13601347
SmallVector<size_t> newRetIndices;
13611348
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13691356
if (is0dOrVec1Extract) {
13701357
Value newExtract;
13711358
if (extractSrcType.getRank() == 1) {
1372-
newExtract = rewriter.create<vector::ExtractElementOp>(
1373-
loc, distributedVec,
1374-
rewriter.create<arith::ConstantIndexOp>(loc, 0));
1375-
1359+
newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
13761360
} else {
1377-
newExtract =
1378-
rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1361+
newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
1362+
ArrayRef<int64_t>{});
13791363
}
13801364
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
13811365
newExtract);
13821366
return success();
13831367
}
13841368

1369+
int64_t staticPos = extractOp.getStaticPosition()[0];
1370+
OpFoldResult pos = ShapedType::isDynamic(staticPos)
1371+
? (newWarpOp->getResult(newRetIndices[1]))
1372+
: OpFoldResult(rewriter.getIndexAttr(staticPos));
13851373
// 1d extract: Distribute the source vector. One lane extracts and shuffles
13861374
// the value to all other lanes.
13871375
int64_t elementsPerLane = distributedVecType.getShape()[0];
13881376
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
13891377
// tid of extracting thread: pos / elementsPerLane
1390-
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1391-
loc, sym0.ceilDiv(elementsPerLane),
1392-
newWarpOp->getResult(newRetIndices[1]));
1378+
Value broadcastFromTid = affine::makeComposedAffineApply(
1379+
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
13931380
// Extract at position: pos % elementsPerLane
1394-
Value pos =
1381+
Value newPos =
13951382
elementsPerLane == 1
13961383
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1397-
: rewriter
1398-
.create<affine::AffineApplyOp>(
1399-
loc, sym0 % elementsPerLane,
1400-
newWarpOp->getResult(newRetIndices[1]))
1401-
.getResult();
1384+
: affine::makeComposedAffineApply(rewriter, loc,
1385+
sym0 % elementsPerLane, pos);
14021386
Value extracted =
1403-
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1387+
rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
14041388

14051389
// Shuffle the extracted value to all lanes.
14061390
Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14131397
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
14141398
};
14151399

1416-
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1400+
/// Pattern to convert vector.extractelement to vector.extract.
1401+
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1402+
WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
1403+
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
1404+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1405+
PatternRewriter &rewriter) const override {
1406+
OpOperand *operand =
1407+
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1408+
if (!operand)
1409+
return failure();
1410+
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1411+
rewriter.setInsertionPoint(extractOp);
1412+
if (auto pos = extractOp.getPosition()) {
1413+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1414+
extractOp, extractOp.getVector(), pos);
1415+
} else {
1416+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1417+
extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
1418+
}
1419+
return success();
1420+
}
1421+
};
1422+
1423+
/// Pattern to move out vector.insert with a scalar input.
1424+
/// Only supports 1-D and 0-D destinations for now.
1425+
struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
14171426
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
14181427

14191428
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
14201429
PatternRewriter &rewriter) const override {
1421-
OpOperand *operand =
1422-
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1430+
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
14231431
if (!operand)
14241432
return failure();
14251433
unsigned int operandNumber = operand->getOperandNumber();
1426-
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1434+
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
14271435
VectorType vecType = insertOp.getDestVectorType();
14281436
VectorType distrType =
14291437
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1430-
bool hasPos = static_cast<bool>(insertOp.getPosition());
1438+
1439+
// Only supports 1-D or 0-D destinations for now.
1440+
if (vecType.getRank() > 1) {
1441+
return rewriter.notifyMatchFailure(
1442+
insertOp, "only 0-D or 1-D source supported for now");
1443+
}
14311444

14321445
// Yield destination vector, source scalar and position from warp op.
14331446
SmallVector<Value> additionalResults{insertOp.getDest(),
14341447
insertOp.getSource()};
14351448
SmallVector<Type> additionalResultTypes{distrType,
14361449
insertOp.getSource().getType()};
1437-
if (hasPos) {
1438-
additionalResults.push_back(insertOp.getPosition());
1439-
additionalResultTypes.push_back(insertOp.getPosition().getType());
1440-
}
1450+
additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1451+
additionalResultTypes.append(
1452+
SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1453+
14411454
Location loc = insertOp.getLoc();
14421455
SmallVector<size_t> newRetIndices;
14431456
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14461459
rewriter.setInsertionPointAfter(newWarpOp);
14471460
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
14481461
Value newSource = newWarpOp->getResult(newRetIndices[1]);
1449-
Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
14501462
rewriter.setInsertionPointAfter(newWarpOp);
14511463

1464+
OpFoldResult pos;
1465+
if (vecType.getRank() != 0) {
1466+
int64_t staticPos = insertOp.getStaticPosition()[0];
1467+
pos = ShapedType::isDynamic(staticPos)
1468+
? (newWarpOp->getResult(newRetIndices[2]))
1469+
: OpFoldResult(rewriter.getIndexAttr(staticPos));
1470+
}
1471+
1472+
// This condition is always true for 0-d vectors.
14521473
if (vecType == distrType) {
1453-
// Broadcast: Simply move the vector.inserelement op out.
1454-
Value newInsert = rewriter.create<vector::InsertElementOp>(
1455-
loc, newSource, distributedVec, newPos);
1474+
Value newInsert;
1475+
if (pos) {
1476+
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1477+
distributedVec, pos);
1478+
} else {
1479+
newInsert = rewriter.create<vector::InsertOp>(
1480+
loc, newSource, distributedVec, ArrayRef<int64_t>{});
1481+
}
1482+
// Broadcast: Simply move the vector.insert op out.
14561483
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
14571484
newInsert);
14581485
return success();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14621489
int64_t elementsPerLane = distrType.getShape()[0];
14631490
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
14641491
// tid of inserting thread: pos / elementsPerLane
1465-
Value insertingLane = rewriter.create<affine::AffineApplyOp>(
1466-
loc, sym0.ceilDiv(elementsPerLane), newPos);
1492+
Value insertingLane = affine::makeComposedAffineApply(
1493+
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
14671494
// Insert position: pos % elementsPerLane
1468-
Value pos =
1469-
elementsPerLane == 1
1470-
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1471-
: rewriter
1472-
.create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473-
newPos)
1474-
.getResult();
1495+
OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1496+
rewriter, loc, sym0 % elementsPerLane, pos);
14751497
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
14761498
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
14771499
Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14801502
loc, isInsertingLane,
14811503
/*thenBuilder=*/
14821504
[&](OpBuilder &builder, Location loc) {
1483-
Value newInsert = builder.create<vector::InsertElementOp>(
1484-
loc, newSource, distributedVec, pos);
1505+
Value newInsert = builder.create<vector::InsertOp>(
1506+
loc, newSource, distributedVec, newPos);
14851507
builder.create<scf::YieldOp>(loc, newInsert);
14861508
},
14871509
/*elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
15061528
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
15071529
Location loc = insertOp.getLoc();
15081530

1509-
// "vector.insert %v, %v[] : ..." can be canonicalized to %v.
1510-
if (insertOp.getNumIndices() == 0)
1531+
// For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1532+
if (insertOp.getDestVectorType().getRank() <= 1) {
15111533
return failure();
1512-
1513-
// Rewrite vector.insert with 1d dest to vector.insertelement.
1514-
if (insertOp.getDestVectorType().getRank() == 1) {
1515-
if (insertOp.hasDynamicPosition())
1516-
// TODO: Dinamic position not supported yet.
1517-
return failure();
1518-
1519-
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1520-
int64_t pos = insertOp.getStaticPosition()[0];
1521-
rewriter.setInsertionPoint(insertOp);
1522-
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1523-
insertOp, insertOp.getSource(), insertOp.getDest(),
1524-
rewriter.create<arith::ConstantIndexOp>(loc, pos));
1525-
return success();
15261534
}
15271535

1536+
// All following cases are 2d or higher dimensional destination vectors.
1537+
15281538
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
15291539
// There is no distribution, this is a broadcast. Simply move the insert
15301540
// out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
16201630
}
16211631
};
16221632

1633+
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1634+
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1635+
1636+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1637+
PatternRewriter &rewriter) const override {
1638+
OpOperand *operand =
1639+
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1640+
if (!operand)
1641+
return failure();
1642+
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1643+
rewriter.setInsertionPoint(insertOp);
1644+
if (auto pos = insertOp.getPosition()) {
1645+
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1646+
insertOp, insertOp.getSource(), insertOp.getDest(), pos);
1647+
} else {
1648+
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1649+
insertOp, insertOp.getSource(), insertOp.getDest(),
1650+
ArrayRef<int64_t>{});
1651+
}
1652+
return success();
1653+
}
1654+
};
1655+
16231656
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624-
/// the scf.ForOp is the last operation in the region so that it doesn't change
1625-
/// the order of execution. This creates a new scf.for region after the
1657+
/// the scf.ForOp is the last operation in the region so that it doesn't
1658+
/// change the order of execution. This creates a new scf.for region after the
16261659
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
16271660
/// WarpExecuteOnLane0Op region. Example:
16281661
/// ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
16681701
if (!forOp)
16691702
return failure();
16701703
// Collect Values that come from the warp op but are outside the forOp.
1671-
// Those Value needs to be returned by the original warpOp and passed to the
1672-
// new op.
1704+
// Those Values need to be returned by the original warpOp and passed to
1705+
// the new op.
16731706
llvm::SmallSetVector<Value, 32> escapingValues;
16741707
SmallVector<Type> inputTypes;
16751708
SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17151748
OpBuilder::InsertionGuard g(rewriter);
17161749
rewriter.setInsertionPointAfter(newWarpOp);
17171750

1718-
// Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719-
// inside.
1751+
// Create a new for op outside the region with a WarpExecuteOnLane0Op
1752+
// region inside.
17201753
auto newForOp = rewriter.create<scf::ForOp>(
17211754
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
17221755
forOp.getStep(), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17781811
};
17791812

17801813
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781-
/// The vector is reduced in parallel. Currently limited to vector size matching
1782-
/// the warpOp size. E.g.:
1814+
/// The vector is reduced in parallel. Currently limited to vector size
1815+
/// matching the warpOp size. E.g.:
17831816
/// ```
17841817
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
17851818
/// %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
18801913
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
18811914
PatternBenefit readBenefit) {
18821915
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883-
patterns
1884-
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885-
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886-
WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1887-
patterns.getContext(), benefit);
1888-
patterns.add<WarpOpExtractElement>(patterns.getContext(),
1889-
warpShuffleFromIdxFn, benefit);
1916+
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1917+
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1918+
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1919+
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1920+
patterns.getContext(), benefit);
1921+
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1922+
benefit);
18901923
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
18911924
benefit);
18921925
}

0 commit comments

Comments
 (0)