Skip to content

Commit 9434337

Browse files
committed
[mlir][vector] Restrict vector.insert/vector.extract
This patch restricts the use of vector.insert and vector.extract Ops in the Vector dialect. Specifically: * The non-indexed operands for `vector.insert` and `vector.extract` must now be non-0-D vectors. The following are now illegal. Note that the source and result types (i.e. non-indexed args) are rank-0 vectors: ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` Put differently, this PR removes the ambiguity when it comes to non-indexed operands of `vector.insert` and `vector.extract`. By requiring that only one form is used, it eliminates the flexibility of allowing both, thereby simplifying the semantics. For more context, see the related RFC: * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
1 parent f99b190 commit 9434337

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,10 @@ struct UnrollTransferReadConversion
12871287

12881288
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12891289
/// accesses, and broadcasts and transposes in permutation maps.
1290+
///
1291+
/// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
1292+
/// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
1293+
/// MemRef and Tensor source, respectively).
12901294
LogicalResult matchAndRewrite(TransferReadOp xferOp,
12911295
PatternRewriter &rewriter) const override {
12921296
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1319,6 +1323,8 @@ struct UnrollTransferReadConversion
13191323
for (int64_t i = 0; i < dimSize; ++i) {
13201324
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
13211325

1326+
// FIXME: Rename this lambda - it does much more than just
1327+
// in-bounds-check generation.
13221328
vec = generateInBoundsCheck(
13231329
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
13241330
/*inBoundsCase=*/
@@ -1333,12 +1339,34 @@ struct UnrollTransferReadConversion
13331339
insertionIndices.push_back(rewriter.getIndexAttr(i));
13341340

13351341
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1336-
auto newXferOp = b.create<vector::TransferReadOp>(
1337-
loc, newXferVecType, xferOp.getSource(), xferIndices,
1338-
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1339-
xferOp.getPadding(), Value(), inBoundsAttr);
1340-
maybeAssignMask(b, xferOp, newXferOp, i);
1341-
return b.create<vector::InsertOp>(loc, newXferOp, vec,
1342+
1343+
// A value that's read after rank-reducing the original
1344+
// vector.transfer_read Op.
1345+
Value unpackedReadRes;
1346+
if (newXferVecType.getRank() != 0) {
1347+
// Unpacking Vector that's rank > 2
1348+
// (use vector.transfer_read to load a rank-reduced vector)
1349+
unpackedReadRes = b.create<vector::TransferReadOp>(
1350+
loc, newXferVecType, xferOp.getSource(), xferIndices,
1351+
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1352+
xferOp.getPadding(), Value(), inBoundsAttr);
1353+
maybeAssignMask(b, xferOp,
1354+
dyn_cast<vector::TransferReadOp>(
1355+
unpackedReadRes.getDefiningOp()),
1356+
i);
1357+
} else {
1358+
// Unpacking Vector that's rank == 1
1359+
// (use memref.load/tensor.extract to load a scalar)
1360+
unpackedReadRes =
1361+
dyn_cast<MemRefType>(xferOp.getSource().getType())
1362+
? b.create<memref::LoadOp>(loc, xferOp.getSource(),
1363+
xferIndices)
1364+
.getResult()
1365+
: b.create<tensor::ExtractOp>(loc, xferOp.getSource(),
1366+
xferIndices)
1367+
.getResult();
1368+
}
1369+
return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
13421370
insertionIndices);
13431371
},
13441372
/*outOfBoundsCase=*/

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
13401340
}
13411341

13421342
LogicalResult vector::ExtractOp::verify() {
1343+
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1344+
if (resTy.getRank() == 0)
1345+
return emitError(
1346+
"expected a scalar instead of a 0-d vector as the result type");
1347+
13431348
// Note: This check must come before getMixedPosition() to prevent a crash.
13441349
auto dynamicMarkersCount =
13451350
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2864,6 +2869,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
28642869
}
28652870

28662871
LogicalResult InsertOp::verify() {
2872+
if (auto srcTy = dyn_cast<VectorType>(getSourceType()))
2873+
if (srcTy.getRank() == 0)
2874+
return emitError(
2875+
"expected a scalar instead of a 0-d vector as the source operand");
2876+
28672877
SmallVector<OpFoldResult> position = getMixedPosition();
28682878
auto destVectorType = getDestVectorType();
28692879
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
260260
// -----
261261

262262
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263-
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
264-
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
263+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
264+
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
265265
}
266266

267267
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
284284
}
285285

286286
// CHECK-LABEL: @insert_0d
287-
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
287+
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
288288
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
289289
%1 = vector.insert %a, %b[] : f32 into vector<f32>
290-
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
291-
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
292-
return %1, %2 : vector<f32>, vector<2x3xf32>
290+
return %1 : vector<f32>
293291
}
294292

295293
// CHECK-LABEL: @outerproduct

0 commit comments

Comments
 (0)