Skip to content

Commit bbe35d6

Browse files
committed
[mlir][vector] Fix parser of vector.transfer_read
1 parent f671fe9 commit bbe35d6

File tree

2 files changed

+42
-33
lines changed

2 files changed

+42
-33
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
152152

153153
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154154
VectorType vectorType) {
155+
int64_t elementVectorRank = 0;
156+
VectorType elementVectorType =
157+
llvm::dyn_cast<VectorType>(shapedType.getElementType());
158+
if (elementVectorType)
159+
elementVectorRank += elementVectorType.getRank();
155160
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
156161
// TODO: replace once we have 0-d vectors.
157162
if (shapedType.getRank() == 0 &&
158163
vectorType.getShape() == ArrayRef<int64_t>{1})
159164
return AffineMap::get(
160165
/*numDims=*/0, /*numSymbols=*/0,
161166
getAffineConstantExpr(0, shapedType.getContext()));
162-
int64_t elementVectorRank = 0;
163-
VectorType elementVectorType =
164-
llvm::dyn_cast<VectorType>(shapedType.getElementType());
165-
if (elementVectorType)
166-
elementVectorRank += elementVectorType.getRank();
167-
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
168-
// Not enough dimensions in the shaped type to form a minor identity map.
169-
return AffineMap();
170-
}
171167
return AffineMap::getMinorIdentityMap(
172168
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
173169
shapedType.getContext());
@@ -4263,18 +4259,16 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42634259
Attribute permMapAttr = result.attributes.get(permMapAttrName);
42644260
AffineMap permMap;
42654261
if (!permMapAttr) {
4262+
int64_t elementVectorRank = 0;
4263+
VectorType elementVectorType =
4264+
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4265+
if (elementVectorType)
4266+
elementVectorRank += elementVectorType.getRank();
4267+
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4268+
return parser.emitError(typesLoc,
4269+
"expected a custom permutation_map when "
4270+
"rank(source) != rank(destination)");
42664271
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4267-
if (!permMap) {
4268-
int64_t elementVectorRank = 0;
4269-
VectorType elementVectorType =
4270-
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4271-
if (elementVectorType)
4272-
elementVectorRank += elementVectorType.getRank();
4273-
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4274-
return parser.emitError(typesLoc,
4275-
"expected a custom permutation_map when source "
4276-
"rank is less than required for vector rank");
4277-
}
42784272
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
42794273
} else {
42804274
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
@@ -4682,18 +4676,16 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46824676
auto permMapAttr = result.attributes.get(permMapAttrName);
46834677
AffineMap permMap;
46844678
if (!permMapAttr) {
4679+
int64_t elementVectorRank = 0;
4680+
VectorType elementVectorType =
4681+
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4682+
if (elementVectorType)
4683+
elementVectorRank += elementVectorType.getRank();
4684+
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4685+
return parser.emitError(typesLoc,
4686+
"expected a custom permutation_map when "
4687+
"rank(source) != rank(destination)");
46854688
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4686-
if (!permMap) {
4687-
int64_t elementVectorRank = 0;
4688-
VectorType elementVectorType =
4689-
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4690-
if (elementVectorType)
4691-
elementVectorRank += elementVectorType.getRank();
4692-
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4693-
return parser.emitError(typesLoc,
4694-
"expected a custom permutation_map when result "
4695-
"rank is less than required for vector rank");
4696-
}
46974689
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
46984690
} else {
46994691
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,13 +527,22 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
527527

528528
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
529529
%c3 = arith.constant 3 : index
530-
// expected-error@+1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
530+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
531531
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
532532
return %0 : vector<3x4xi32>
533533
}
534534

535535
// -----
536536

537+
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
538+
%c3 = arith.constant 3 : index
539+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
540+
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
541+
return %0 : vector<3x4xindex>
542+
}
543+
544+
// -----
545+
537546
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
538547
%c3 = arith.constant 3 : index
539548
%cst = arith.constant 3.0 : f32
@@ -657,12 +666,20 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
657666

658667
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
659668
%c3 = arith.constant 3 : index
660-
// expected-error@+1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
669+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
661670
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
662671
}
663672

664673
// -----
665674

675+
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
676+
%c3 = arith.constant 3 : index
677+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
678+
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
679+
}
680+
681+
// -----
682+
666683
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
667684
// expected-error@+1 {{expected offsets of same size as destination vector rank}}
668685
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>

0 commit comments

Comments
 (0)