Skip to content

Commit 0962edb

Browse files
committed
[mlir][vector] Restrict vector.insert/vector.extract
This patch restricts the use of vector.insert and vector.extract Ops in the Vector dialect. Specifically: * The non-indexed operands for `vector.insert` and `vector.extract` must now be non-0-D vectors. The following are now illegal. Note that the source and result types (i.e. non-indexed args) are rank-0 vectors: ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` Put differently, this PR removes the ambiguity when it comes to non-indexed operands of `vector.insert` and `vector.extract`. By requiring that only one form is used, it eliminates the flexibility of allowing both, thereby simplifying the semantics. For more context, see the related RFC: * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
1 parent 45e874e commit 0962edb

File tree

4 files changed

+30
-12
lines changed

4 files changed

+30
-12
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,12 +1333,22 @@ struct UnrollTransferReadConversion
13331333
insertionIndices.push_back(rewriter.getIndexAttr(i));
13341334

13351335
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1336-
auto newXferOp = b.create<vector::TransferReadOp>(
1337-
loc, newXferVecType, xferOp.getSource(), xferIndices,
1338-
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1339-
xferOp.getPadding(), Value(), inBoundsAttr);
1340-
maybeAssignMask(b, xferOp, newXferOp, i);
1341-
return b.create<vector::InsertOp>(loc, newXferOp, vec,
1336+
Value newXferOrLoadOp;
1337+
if (newXferVecType.getRank() != 0) {
1338+
newXferOrLoadOp = b.create<vector::TransferReadOp>(
1339+
loc, newXferVecType, xferOp.getSource(), xferIndices,
1340+
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1341+
xferOp.getPadding(), Value(), inBoundsAttr);
1342+
maybeAssignMask(b, xferOp,
1343+
dyn_cast<vector::TransferReadOp>(
1344+
newXferOrLoadOp.getDefiningOp()),
1345+
i);
1346+
} else {
1347+
// TODO: Generalize so that this also works for Tensors.
1348+
newXferOrLoadOp = b.create<memref::LoadOp>(
1349+
loc, xferOp.getSource(), xferIndices);
1350+
}
1351+
return b.create<vector::InsertOp>(loc, newXferOrLoadOp, vec,
13421352
insertionIndices);
13431353
},
13441354
/*outOfBoundsCase=*/

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
13401340
}
13411341

13421342
LogicalResult vector::ExtractOp::verify() {
1343+
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1344+
if (resTy.getRank() == 0)
1345+
return emitError(
1346+
"expected a scalar instead of a 0-d vector as the result type");
1347+
13431348
// Note: This check must come before getMixedPosition() to prevent a crash.
13441349
auto dynamicMarkersCount =
13451350
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2864,6 +2869,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
28642869
}
28652870

28662871
LogicalResult InsertOp::verify() {
2872+
if (auto srcTy = dyn_cast<VectorType>(getSourceType()))
2873+
if (srcTy.getRank() == 0)
2874+
return emitError(
2875+
"expected a scalar instead of a 0-d vector as the source operand");
2876+
28672877
SmallVector<OpFoldResult> position = getMixedPosition();
28682878
auto destVectorType = getDestVectorType();
28692879
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
260260
// -----
261261

262262
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263-
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
264-
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
263+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
264+
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
265265
}
266266

267267
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
284284
}
285285

286286
// CHECK-LABEL: @insert_0d
287-
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
287+
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
288288
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
289289
%1 = vector.insert %a, %b[] : f32 into vector<f32>
290-
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
291-
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
292-
return %1, %2 : vector<f32>, vector<2x3xf32>
290+
return %1 : vector<f32>
293291
}
294292

295293
// CHECK-LABEL: @outerproduct

0 commit comments

Comments
 (0)