Skip to content

Commit 823909c

Browse files
committed
Address comments
1 parent 6f867b1 commit 823909c

File tree

1 file changed

+17
-24
lines changed

1 file changed

+17
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,12 +1355,9 @@ struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
13551355
// All lanes extract the scalar.
13561356
if (is0dOrVec1Extract) {
13571357
Value newExtract;
1358-
if (extractSrcType.getRank() == 1) {
1359-
newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
1360-
} else {
1361-
newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
1362-
ArrayRef<int64_t>{});
1363-
}
1358+
SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1359+
newExtract =
1360+
rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
13641361
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
13651362
newExtract);
13661363
return success();
@@ -1408,14 +1405,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14081405
if (!operand)
14091406
return failure();
14101407
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1411-
rewriter.setInsertionPoint(extractOp);
1408+
SmallVector<OpFoldResult> indices;
14121409
if (auto pos = extractOp.getPosition()) {
1413-
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1414-
extractOp, extractOp.getVector(), pos);
1415-
} else {
1416-
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1417-
extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
1410+
indices.push_back(pos);
14181411
}
1412+
rewriter.setInsertionPoint(extractOp);
1413+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1414+
extractOp, extractOp.getVector(), indices);
14191415
return success();
14201416
}
14211417
};
@@ -1472,13 +1468,12 @@ struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
14721468
// This condition is always true for 0-d vectors.
14731469
if (vecType == distrType) {
14741470
Value newInsert;
1471+
SmallVector<OpFoldResult> indices;
14751472
if (pos) {
1476-
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1477-
distributedVec, pos);
1478-
} else {
1479-
newInsert = rewriter.create<vector::InsertOp>(
1480-
loc, newSource, distributedVec, ArrayRef<int64_t>{});
1473+
indices.push_back(pos);
14811474
}
1475+
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1476+
distributedVec, indices);
14821477
// Broadcast: Simply move the vector.insert op out.
14831478
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
14841479
newInsert);
@@ -1640,15 +1635,13 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
16401635
if (!operand)
16411636
return failure();
16421637
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1643-
rewriter.setInsertionPoint(insertOp);
1638+
SmallVector<OpFoldResult> indices;
16441639
if (auto pos = insertOp.getPosition()) {
1645-
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1646-
insertOp, insertOp.getSource(), insertOp.getDest(), pos);
1647-
} else {
1648-
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1649-
insertOp, insertOp.getSource(), insertOp.getDest(),
1650-
ArrayRef<int64_t>{});
1640+
indices.push_back(pos);
16511641
}
1642+
rewriter.setInsertionPoint(insertOp);
1643+
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1644+
insertOp, insertOp.getSource(), insertOp.getDest(), indices);
16521645
return success();
16531646
}
16541647
};

0 commit comments

Comments
 (0)