Skip to content

Commit efc29a7

Browse files
committed
[mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract
This patch enforces a restriction in the Vector dialect: the non-indexed operands of `vector.insert` and `vector.extract` must no longer be 0-D vectors. In other words, rank-0 vector types like `vector<f32>` are disallowed as the source or result. EXAMPLES -------- The following are now **illegal** (note the use of `vector<f32>`): ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` This change serves three goals: 1. REDUCED AMBIGUITY -------------------- By enforcing scalar-only semantics when n-k = 0, we eliminate ambiguity in interpretation. Prior to this patch, both `f32` and `vector<f32>` were accepted in practice, though only scalars were intended. 2. MATCH IMPLEMENTATION TO DOCUMENTATION ---------------------------------------- The current behavior contradicts the documented intent. For example, vector.extract states: > Degenerates to an element type if n-k is zero. This patch enforces that intent in code. 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT --------------------------------------------- With the stricter semantics in place, it’s natural and consistent to make `vector.insert` behave symmetrically to `vector.extract`, i.e., degenerate the source type to a scalar when n = 0. NOTES FOR REVIEWERS ------------------- 1. Main change is in "VectorOps.cpp", where stricter type checks are implemented. 2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to remove now-illegal examples. 2. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now avoid using `vector.transfer_read` for scalar loads and instead rely on `memref.load` / `tensor.extract`. RELATED RFC ----------- * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
1 parent e4de74b commit efc29a7

File tree

5 files changed

+59
-31
lines changed

5 files changed

+59
-31
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,9 @@ def Vector_ExtractOp :
691691
InferTypeOpAdaptorWithIsCompatible]> {
692692
let summary = "extract operation";
693693
let description = [{
694-
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
695-
the proper position. Degenerates to an element type if n-k is zero.
694+
Extracts an (n − k)-D subvector (the result) from an n-D vector at a
695+
specified k-D position. When n = k, the result degenerates to a scalar
696+
element.
696697

697698
Static and dynamic indices must be greater or equal to zero and less than
698699
the size of the corresponding dimension. The result is undefined if any
@@ -704,7 +705,6 @@ def Vector_ExtractOp :
704705
```mlir
705706
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
706707
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
707-
%3 = vector.extract %1[]: vector<f32> from vector<f32>
708708
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
709709
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
710710
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -886,9 +886,10 @@ def Vector_InsertOp :
886886
AllTypesMatch<["dest", "result"]>]> {
887887
let summary = "insert operation";
888888
let description = [{
889-
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
890-
and inserts the n-D source into the (n+k)-D destination at the proper
891-
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
889+
Inserts an n-D source vector (the value to store) into an (n + k)-D
890+
destination vector at a specified k-D position. When n = 0, the source
891+
degenerates to a scalar element inserted into the (0 + k)-D destination
892+
vector.
892893

893894
Static and dynamic indices must be greater or equal to zero and less than
894895
the size of the corresponding dimension. The result is undefined if any
@@ -900,8 +901,7 @@ def Vector_InsertOp :
900901
```mlir
901902
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
902903
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
903-
%8 = vector.insert %6, %7[] : f32 into vector<f32>
904-
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
904+
%11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
905905
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
906906
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
907907
```

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion
12941294

12951295
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12961296
/// accesses, and broadcasts and transposes in permutation maps.
1297+
///
1298+
/// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
1299+
/// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
1300+
/// MemRef and Tensor source, respectively).
12971301
LogicalResult matchAndRewrite(TransferReadOp xferOp,
12981302
PatternRewriter &rewriter) const override {
12991303
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion
13241328
for (int64_t i = 0; i < dimSize; ++i) {
13251329
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
13261330

1331+
// FIXME: Rename this lambda - it does much more than just
1332+
// in-bounds-check generation.
13271333
vec = generateInBoundsCheck(
13281334
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
13291335
/*inBoundsCase=*/
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
13381344
insertionIndices.push_back(rewriter.getIndexAttr(i));
13391345

13401346
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1341-
auto newXferOp = b.create<vector::TransferReadOp>(
1342-
loc, newXferVecType, xferOp.getBase(), xferIndices,
1343-
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1344-
xferOp.getPadding(), Value(), inBoundsAttr);
1345-
maybeAssignMask(b, xferOp, newXferOp, i);
1346-
return b.create<vector::InsertOp>(loc, newXferOp, vec,
1347+
1348+
// A value that's read after rank-reducing the original
1349+
// vector.transfer_read Op.
1350+
Value unpackedReadRes;
1351+
if (newXferVecType.getRank() != 0) {
1352+
// Unpacking Vector that's rank > 2
1353+
// (use vector.transfer_read to load a rank-reduced vector)
1354+
unpackedReadRes = b.create<vector::TransferReadOp>(
1355+
loc, newXferVecType, xferOp.getBase(), xferIndices,
1356+
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1357+
xferOp.getPadding(), Value(), inBoundsAttr);
1358+
maybeAssignMask(b, xferOp,
1359+
dyn_cast<vector::TransferReadOp>(
1360+
unpackedReadRes.getDefiningOp()),
1361+
i);
1362+
} else {
1363+
// Unpacking Vector that's rank == 1
1364+
// (use memref.load/tensor.extract to load a scalar)
1365+
unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
1366+
? b.create<memref::LoadOp>(
1367+
loc, xferOp.getBase(), xferIndices)
1368+
.getResult()
1369+
: b.create<tensor::ExtractOp>(
1370+
loc, xferOp.getBase(), xferIndices)
1371+
.getResult();
1372+
}
1373+
return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
13471374
insertionIndices);
13481375
},
13491376
/*outOfBoundsCase=*/

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
13831383
}
13841384

13851385
LogicalResult vector::ExtractOp::verify() {
1386+
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1387+
if (resTy.getRank() == 0)
1388+
return emitError(
1389+
"expected a scalar instead of a 0-d vector as the result type");
1390+
13861391
// Note: This check must come before getMixedPosition() to prevent a crash.
13871392
auto dynamicMarkersCount =
13881393
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
31223127
}
31233128

31243129
LogicalResult InsertOp::verify() {
3130+
if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3131+
if (srcTy.getRank() == 0)
3132+
return emitError(
3133+
"expected a scalar instead of a 0-d vector as the source operand");
3134+
31253135
SmallVector<OpFoldResult> position = getMixedPosition();
31263136
auto destVectorType = getDestVectorType();
31273137
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
178178

179179
// -----
180180

181-
func.func @extract_0d(%arg0: vector<f32>) {
182-
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
183-
%1 = vector.extract %arg0[0] : f32 from vector<f32>
181+
func.func @extract_0d_result(%arg0: vector<f32>) {
182+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}}
183+
%1 = vector.extract %arg0[] : vector<f32> from vector<f32>
184184
}
185185

186186
// -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
259259

260260
// -----
261261

262-
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263-
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
264-
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
265-
}
266-
267-
// -----
268-
269-
func.func @insert_0d(%a: f32, %b: vector<f32>) {
270-
// expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
271-
%1 = vector.insert %a, %b[0] : f32 into vector<f32>
262+
func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
264+
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
272265
}
273266

274267
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
298298
}
299299

300300
// CHECK-LABEL: @insert_0d
301-
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
301+
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
302302
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
303303
%1 = vector.insert %a, %b[] : f32 into vector<f32>
304-
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
305-
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
306-
return %1, %2 : vector<f32>, vector<2x3xf32>
304+
return %1 : vector<f32>
307305
}
308306

309307
// CHECK-LABEL: @insert_poison_idx

0 commit comments

Comments
 (0)