@@ -152,22 +152,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
152152
153153AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
154154 VectorType vectorType) {
155+ int64_t elementVectorRank = 0 ;
156+ VectorType elementVectorType =
157+ llvm::dyn_cast<VectorType>(shapedType.getElementType ());
158+ if (elementVectorType)
159+ elementVectorRank += elementVectorType.getRank ();
155160 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
156161 // TODO: replace once we have 0-d vectors.
157162 if (shapedType.getRank () == 0 &&
158163 vectorType.getShape () == ArrayRef<int64_t >{1 })
159164 return AffineMap::get (
160165 /* numDims=*/ 0 , /* numSymbols=*/ 0 ,
161166 getAffineConstantExpr (0 , shapedType.getContext ()));
162- int64_t elementVectorRank = 0 ;
163- VectorType elementVectorType =
164- llvm::dyn_cast<VectorType>(shapedType.getElementType ());
165- if (elementVectorType)
166- elementVectorRank += elementVectorType.getRank ();
167- if (shapedType.getRank () < vectorType.getRank () - elementVectorRank) {
168- // Not enough dimensions in the shaped type to form a minor identity map.
169- return AffineMap ();
170- }
171167 return AffineMap::getMinorIdentityMap (
172168 shapedType.getRank (), vectorType.getRank () - elementVectorRank,
173169 shapedType.getContext ());
@@ -4263,18 +4259,16 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42634259 Attribute permMapAttr = result.attributes .get (permMapAttrName);
42644260 AffineMap permMap;
42654261 if (!permMapAttr) {
4262+ int64_t elementVectorRank = 0 ;
4263+ VectorType elementVectorType =
4264+ llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4265+ if (elementVectorType)
4266+ elementVectorRank += elementVectorType.getRank ();
4267+ if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4268+ return parser.emitError (typesLoc,
4269+ " expected a custom permutation_map when "
4270+ " rank(source) != rank(destination)" );
42664271 permMap = getTransferMinorIdentityMap (shapedType, vectorType);
4267- if (!permMap) {
4268- int64_t elementVectorRank = 0 ;
4269- VectorType elementVectorType =
4270- llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4271- if (elementVectorType)
4272- elementVectorRank += elementVectorType.getRank ();
4273- if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4274- return parser.emitError (typesLoc,
4275- " expected a custom permutation_map when source "
4276- " rank is less than required for vector rank" );
4277- }
42784272 result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
42794273 } else {
42804274 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue ();
@@ -4682,18 +4676,16 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46824676 auto permMapAttr = result.attributes .get (permMapAttrName);
46834677 AffineMap permMap;
46844678 if (!permMapAttr) {
4679+ int64_t elementVectorRank = 0 ;
4680+ VectorType elementVectorType =
4681+ llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4682+ if (elementVectorType)
4683+ elementVectorRank += elementVectorType.getRank ();
4684+ if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4685+ return parser.emitError (typesLoc,
4686+ " expected a custom permutation_map when "
4687+ " rank(source) != rank(destination)" );
46854688 permMap = getTransferMinorIdentityMap (shapedType, vectorType);
4686- if (!permMap) {
4687- int64_t elementVectorRank = 0 ;
4688- VectorType elementVectorType =
4689- llvm::dyn_cast<VectorType>(shapedType.getElementType ());
4690- if (elementVectorType)
4691- elementVectorRank += elementVectorType.getRank ();
4692- if (shapedType.getRank () < vectorType.getRank () - elementVectorRank)
4693- return parser.emitError (typesLoc,
4694- " expected a custom permutation_map when result "
4695- " rank is less than required for vector rank" );
4696- }
46974689 result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
46984690 } else {
46994691 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue ();
0 commit comments