diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3fee1e949aeed..49dd433597e8c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind, return false; } -/// Returns the number of dimensions of the `shapedType` that participate in the -/// vector transfer, effectively the rank of the vector dimensions within the -/// `shapedType`. This is calculated by taking the rank of the `vectorType` -/// being transferred and subtracting the rank of the `shapedType`'s element -/// type if it's also a vector. +/// Returns the effective rank of the vector to read/write for Xfer Ops /// -/// This is used to determine the number of minor dimensions for identity maps -/// in vector transfers. +/// When the element type of the shaped type is _a scalar_, this will simply +/// return the rank of the vector ( the result for xfer_read or the value to +/// store for xfer_write). /// -/// For example, given a transfer operation involving `shapedType` and -/// `vectorType`: +/// When the element type of the base shaped type is _a vector_, returns the +/// difference between the original vector type and the element type of the +/// shaped type. /// +/// EXAMPLE 1 (element type is _a scalar_): /// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32> /// - shapedType.getElementType() = f32 (rank 0) /// - vectorType.getRank() = 2 /// - Result = 2 - 0 = 2 /// +/// EXAMPLE 2 (element type is _a vector_): /// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32> /// - shapedType.getElementType() = vector<20xf32> (rank 1) /// - vectorType.getRank() = 1 /// - Result = 1 - 1 = 0 -static unsigned getRealVectorRank(ShapedType shapedType, - VectorType vectorType) { +/// +/// This is used to determine the number of minor dimensions for identity maps +/// in vector transfer Ops. +static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, + VectorType vectorType) { unsigned elementVectorRank = 0; VectorType elementVectorType = llvm::dyn_cast(shapedType.getElementType()); @@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, /*numDims=*/0, /*numSymbols=*/0, getAffineConstantExpr(0, shapedType.getContext())); return AffineMap::getMinorIdentityMap( - shapedType.getRank(), getRealVectorRank(shapedType, vectorType), + shapedType.getRank(), + getEffectiveVectorRankForXferOp(shapedType, vectorType), shapedType.getContext()); } @@ -4260,7 +4264,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { Attribute permMapAttr = result.attributes.get(permMapAttrName); AffineMap permMap; if (!permMapAttr) { - if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType)) + if (shapedType.getRank() < + getEffectiveVectorRankForXferOp(shapedType, vectorType)) return parser.emitError(typesLoc, "expected a custom permutation_map when " "rank(source) != rank(destination)"); @@ -4679,7 +4684,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser, auto permMapAttr = result.attributes.get(permMapAttrName); AffineMap permMap; if (!permMapAttr) { - if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType)) + if (shapedType.getRank() < + getEffectiveVectorRankForXferOp(shapedType, vectorType)) return parser.emitError(typesLoc, "expected a custom permutation_map when " "rank(source) != rank(destination)"); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 19096f0e4c895..349a58d4eb4e4 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -525,15 +525,24 @@ func.func @test_vector.transfer_read(%arg0: memref>) { // ----- -func.func @test_vector.transfer_read(%arg1: memref) -> vector<3x4xindex> { +func.func @test_vector.transfer_read(%arg0: memref) -> vector<3x4xindex> { %c3 = arith.constant 3 : index // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} - %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref, vector<3x4xindex> + %0 = vector.transfer_read %arg0[%c3], %c3 : memref, vector<3x4xindex> return %0 : vector<3x4xindex> } // ----- +func.func @test_vector.transfer_write(%arg0: memref>) { + %c3 = arith.constant 3 : index + // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} + %0 = vector.transfer_read %arg0[%c3], %c3 : memref>, vector<2x3x4xindex> + return %0 : vector<2x3x4xindex> +} + +// ----- + func.func @test_vector.transfer_write(%arg0: memref) { %c3 = arith.constant 3 : index %cst = arith.constant 3.0 : f32 @@ -655,10 +664,18 @@ func.func @test_vector.transfer_write(%arg0: memref, %arg1: vector<7xf32> // ----- -func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref) { +func.func @test_vector.transfer_write(%arg0: memref, %arg1: vector<3x4xindex>) { + %c3 = arith.constant 3 : index + // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} + vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<3x4xindex>, memref +} + +// ----- + +func.func @test_vector.transfer_write(%arg0: memref>, %arg1: vector<2x3x4xindex>) { %c3 = arith.constant 3 : index // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} - vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref + vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<2x3x4xindex>, memref> } // -----