Skip to content

Commit 2a920e0

Browse files
committed
[mlir][vector] Fix parser of vector.transfer_read
1 parent bbe35d6 commit 2a920e0

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
150150
return false;
151151
}
152152

153-
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154-
VectorType vectorType) {
155-
int64_t elementVectorRank = 0;
153+
static unsigned getRealVectorRank(ShapedType shapedType,
154+
VectorType vectorType) {
155+
unsigned elementVectorRank = 0;
156156
VectorType elementVectorType =
157157
llvm::dyn_cast<VectorType>(shapedType.getElementType());
158158
if (elementVectorType)
159159
elementVectorRank += elementVectorType.getRank();
160+
return vectorType.getRank() - elementVectorRank;
161+
}
162+
163+
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
164+
VectorType vectorType) {
160165
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161166
// TODO: replace once we have 0-d vectors.
162167
if (shapedType.getRank() == 0 &&
@@ -165,7 +170,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
165170
/*numDims=*/0, /*numSymbols=*/0,
166171
getAffineConstantExpr(0, shapedType.getContext()));
167172
return AffineMap::getMinorIdentityMap(
168-
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
173+
shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
169174
shapedType.getContext());
170175
}
171176

@@ -4259,12 +4264,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42594264
Attribute permMapAttr = result.attributes.get(permMapAttrName);
42604265
AffineMap permMap;
42614266
if (!permMapAttr) {
4262-
int64_t elementVectorRank = 0;
4263-
VectorType elementVectorType =
4264-
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4265-
if (elementVectorType)
4266-
elementVectorRank += elementVectorType.getRank();
4267-
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4267+
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
42684268
return parser.emitError(typesLoc,
42694269
"expected a custom permutation_map when "
42704270
"rank(source) != rank(destination)");
@@ -4676,12 +4676,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46764676
auto permMapAttr = result.attributes.get(permMapAttrName);
46774677
AffineMap permMap;
46784678
if (!permMapAttr) {
4679-
int64_t elementVectorRank = 0;
4680-
VectorType elementVectorType =
4681-
llvm::dyn_cast<VectorType>(shapedType.getElementType());
4682-
if (elementVectorType)
4683-
elementVectorRank += elementVectorType.getRank();
4684-
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
4679+
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
46854680
return parser.emitError(typesLoc,
46864681
"expected a custom permutation_map when "
46874682
"rank(source) != rank(destination)");

0 commit comments

Comments
 (0)