Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -691,8 +691,9 @@ def Vector_ExtractOp :
InferTypeOpAdaptorWithIsCompatible]> {
let summary = "extract operation";
let description = [{
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
the proper position. Degenerates to an element type if n-k is zero.
Extracts an (n − k)-D result sub-vector from an n-D source vector at a
specified k-D position. When n = k, the result degenerates to a scalar
element.

Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
Expand All @@ -704,7 +705,6 @@ def Vector_ExtractOp :
```mlir
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
%3 = vector.extract %1[]: vector<f32> from vector<f32>
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
Expand Down Expand Up @@ -886,9 +886,9 @@ def Vector_InsertOp :
AllTypesMatch<["dest", "result"]>]> {
let summary = "insert operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
and inserts the n-D source into the (n+k)-D destination at the proper
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination
vector at a specified k-D position. When n = 0, value-to-store degenerates
to a scalar element inserted into the n-D destination vector.

Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
Expand All @@ -900,8 +900,7 @@ def Vector_InsertOp :
```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
%8 = vector.insert %6, %7[] : f32 into vector<f32>
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
```
Expand Down
13 changes: 12 additions & 1 deletion mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,8 @@ struct UnrollTransferReadConversion
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

// FIXME: Rename this lambda - it does much more than just
// in-bounds-check generation.
vec = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
/*inBoundsCase=*/
Expand All @@ -1338,12 +1340,21 @@ struct UnrollTransferReadConversion
insertionIndices.push_back(rewriter.getIndexAttr(i));

auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());

auto newXferOp = b.create<vector::TransferReadOp>(
loc, newXferVecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
return b.create<vector::InsertOp>(loc, newXferOp, vec,

Value valToInser = newXferOp.getResult();
if (newXferVecType.getRank() == 0) {
// vector.insert does not accept rank-0 as the non-indexed
// argument. Extract the scalar before inserting.
valToInser = b.create<vector::ExtractOp>(loc, valToInser,
SmallVector<int64_t>());
}
return b.create<vector::InsertOp>(loc, valToInser, vec,
insertionIndices);
},
/*outOfBoundsCase=*/
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}

LogicalResult vector::ExtractOp::verify() {
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
if (resTy.getRank() == 0)
return emitError(
"expected a scalar instead of a 0-d vector as the result type");

// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
Expand Down Expand Up @@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}

LogicalResult InsertOp::verify() {
if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
if (srcTy.getRank() == 0)
return emitError(
"expected a scalar instead of a 0-d vector as the source operand");

SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
Expand Down
19 changes: 6 additions & 13 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {

// -----

func.func @extract_0d(%arg0: vector<f32>) {
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
%1 = vector.extract %arg0[0] : f32 from vector<f32>
func.func @extract_0d_result(%arg0: vector<f32>) {
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}}
%1 = vector.extract %arg0[] : vector<f32> from vector<f32>
}

// -----
Expand Down Expand Up @@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {

// -----

func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
}

// -----

func.func @insert_0d(%a: f32, %b: vector<f32>) {
// expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
%1 = vector.insert %a, %b[0] : f32 into vector<f32>
func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
}

// -----
Expand Down
6 changes: 2 additions & 4 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}

// CHECK-LABEL: @insert_0d
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
return %1, %2 : vector<f32>, vector<2x3xf32>
return %1 : vector<f32>
}

// CHECK-LABEL: @insert_poison_idx
Expand Down
Loading