@@ -150,13 +150,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
150150 return false ;
151151}
152152
153- AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
154- VectorType vectorType) {
155- int64_t elementVectorRank = 0 ;
153+ static unsigned getRealVectorRank (ShapedType shapedType,
154+ VectorType vectorType) {
155+ unsigned elementVectorRank = 0 ;
156156 VectorType elementVectorType =
157157 llvm::dyn_cast<VectorType>(shapedType.getElementType ());
158158 if (elementVectorType)
159159 elementVectorRank += elementVectorType.getRank ();
160+ return vectorType.getRank () - elementVectorRank;
161+ }
162+
163+ AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
164+ VectorType vectorType) {
160165 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161166 // TODO: replace once we have 0-d vectors.
162167 if (shapedType.getRank () == 0 &&
@@ -165,7 +170,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
165170 /* numDims=*/ 0 , /* numSymbols=*/ 0 ,
166171 getAffineConstantExpr (0 , shapedType.getContext ()));
167172 return AffineMap::getMinorIdentityMap (
168- shapedType.getRank (), vectorType. getRank () - elementVectorRank ,
173+ shapedType.getRank (), getRealVectorRank (shapedType, vectorType) ,
169174 shapedType.getContext ());
170175}
171176
@@ -4259,12 +4264,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42594264 Attribute permMapAttr = result.attributes .get (permMapAttrName);
42604265 AffineMap permMap;
42614266 if (!permMapAttr) {
4262- int64_t elementVectorRank = 0 ;
4263- VectorType elementVectorType =
4264- llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4265- if (elementVectorType)
4266- elementVectorRank += elementVectorType.getRank ();
4267- if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4267+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
42684268 return parser.emitError (typesLoc,
42694269 " expected a custom permutation_map when "
42704270 " rank(source) != rank(destination)" );
@@ -4676,12 +4676,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46764676 auto permMapAttr = result.attributes .get (permMapAttrName);
46774677 AffineMap permMap;
46784678 if (!permMapAttr) {
4679- int64_t elementVectorRank = 0 ;
4680- VectorType elementVectorType =
4681- llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4682- if (elementVectorType)
4683- elementVectorRank += elementVectorType.getRank ();
4684- if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4679+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
46854680 return parser.emitError (typesLoc,
46864681 " expected a custom permutation_map when "
46874682 " rank(source) != rank(destination)" );
0 commit comments