Skip to content

Commit f671fe9

Browse files
committed
[mlir][vector] Fix parser of vector.transfer_read
1 parent 3a8d23c commit f671fe9

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4265,8 +4265,15 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42654265
if (!permMapAttr) {
42664266
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
42674267
if (!permMap) {
4268-
return parser.emitError(
4269-
typesLoc, "failed to create a minor identity map, source rank is less than required for vector rank");
4268+
int64_t elementVectorRank = 0;
4269+
VectorType elementVectorType =
4270+
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4271+
if (elementVectorType)
4272+
elementVectorRank += elementVectorType.getRank();
4273+
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4274+
return parser.emitError(typesLoc,
4275+
"expected a custom permutation_map when source "
4276+
"rank is less than required for vector rank");
42704277
}
42714278
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
42724279
} else {
@@ -4677,8 +4684,15 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46774684
if (!permMapAttr) {
46784685
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
46794686
if (!permMap) {
4680-
return parser.emitError(
4681-
typesLoc, "failed to create a minor identity map, result rank is less than required for vector rank");
4687+
int64_t elementVectorRank = 0;
4688+
VectorType elementVectorType =
4689+
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4690+
if (elementVectorType)
4691+
elementVectorRank += elementVectorType.getRank();
4692+
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4693+
return parser.emitError(typesLoc,
4694+
"expected a custom permutation_map when result "
4695+
"rank is less than required for vector rank");
46824696
}
46834697
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
46844698
} else {

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
527527

528528
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
529529
%c3 = arith.constant 3 : index
530-
// expected-error@+1 {{failed to create a minor identity map, source rank is less than required for vector rank}}
530+
// expected-error@+1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
531531
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
532532
return %0 : vector<3x4xi32>
533533
}
@@ -656,9 +656,9 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
656656
// -----
657657

658658
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
659-
%c3_idx = arith.constant 3 : index
660-
// expected-error@+1 {{failed to create a minor identity map, result rank is less than required for vector rank}}
661-
vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
659+
%c3 = arith.constant 3 : index
660+
// expected-error@+1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
661+
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
662662
}
663663

664664
// -----

0 commit comments

Comments
 (0)