Skip to content

Commit c15e7dd

Browse files
committed
fixup! [mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract
Address Kunwar's comments
1 parent efc29a7 commit c15e7dd

File tree

2 files changed

+17
-34
lines changed

2 files changed

+17
-34
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def Vector_ExtractOp :
691691
InferTypeOpAdaptorWithIsCompatible]> {
692692
let summary = "extract operation";
693693
let description = [{
694-
Extracts an (n − k)-D subvector (the result) from an n-D vector at a
694+
Extracts an (n − k)-D result sub-vector from an n-D source vector at a
695695
specified k-D position. When n = k, the result degenerates to a scalar
696696
element.
697697

@@ -886,10 +886,9 @@ def Vector_InsertOp :
886886
AllTypesMatch<["dest", "result"]>]> {
887887
let summary = "insert operation";
888888
let description = [{
889-
Inserts an n-D source vector (the value to store) into an (n + k)-D
890-
destination vector at a specified k-D position. When n = 0, the source
891-
degenerates to a scalar element inserted into the (0 + k)-D destination
892-
vector.
889+
Inserts an n-D value-to-store vector into an (n + k)-D destination vector
890+
at a specified k-D position. When n = 0, value-to-store degenerates to
891+
a scalar element inserted into the (0 + k)-D destination vector.
893892

894893
Static and dynamic indices must be greater or equal to zero and less than
895894
the size of the corresponding dimension. The result is undefined if any

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,10 +1294,6 @@ struct UnrollTransferReadConversion
12941294

12951295
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12961296
/// accesses, and broadcasts and transposes in permutation maps.
1297-
///
1298-
/// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
1299-
/// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
1300-
/// MemRef and Tensor source, respectively).
13011297
LogicalResult matchAndRewrite(TransferReadOp xferOp,
13021298
PatternRewriter &rewriter) const override {
13031299
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1345,32 +1341,20 @@ struct UnrollTransferReadConversion
13451341

13461342
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
13471343

1348-
// A value that's read after rank-reducing the original
1349-
// vector.transfer_read Op.
1350-
Value unpackedReadRes;
1351-
if (newXferVecType.getRank() != 0) {
1352-
// Unpacking Vector that's rank > 2
1353-
// (use vector.transfer_read to load a rank-reduced vector)
1354-
unpackedReadRes = b.create<vector::TransferReadOp>(
1355-
loc, newXferVecType, xferOp.getBase(), xferIndices,
1356-
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1357-
xferOp.getPadding(), Value(), inBoundsAttr);
1358-
maybeAssignMask(b, xferOp,
1359-
dyn_cast<vector::TransferReadOp>(
1360-
unpackedReadRes.getDefiningOp()),
1361-
i);
1362-
} else {
1363-
// Unpacking Vector that's rank == 1
1364-
// (use memref.load/tensor.extract to load a scalar)
1365-
unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
1366-
? b.create<memref::LoadOp>(
1367-
loc, xferOp.getBase(), xferIndices)
1368-
.getResult()
1369-
: b.create<tensor::ExtractOp>(
1370-
loc, xferOp.getBase(), xferIndices)
1371-
.getResult();
1344+
auto newXferOp = b.create<vector::TransferReadOp>(
1345+
loc, newXferVecType, xferOp.getBase(), xferIndices,
1346+
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1347+
xferOp.getPadding(), Value(), inBoundsAttr);
1348+
maybeAssignMask(b, xferOp, newXferOp, i);
1349+
1350+
Value valToInser = newXferOp.getResult();
1351+
if (newXferVecType.getRank() == 0) {
1352+
// vector.insert does not accept rank-0 as the non-indexed
1353+
// argument. Extract the scalar before inserting.
1354+
valToInser = b.create<vector::ExtractOp>(loc, valToInser,
1355+
SmallVector<int64_t>());
13721356
}
1373-
return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
1357+
return b.create<vector::InsertOp>(loc, valToInser, vec,
13741358
insertionIndices);
13751359
},
13761360
/*outOfBoundsCase=*/

0 commit comments

Comments
 (0)