@@ -151,13 +151,39 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151151 return false ;
152152}
153153
154- AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
155- VectorType vectorType) {
156- int64_t elementVectorRank = 0 ;
154+ // / Returns the number of dimensions of the `shapedType` that participate in the
155+ // / vector transfer, effectively the rank of the vector dimensions within the
156+ // / `shapedType`. This is calculated by taking the rank of the `vectorType`
157+ // / being transferred and subtracting the rank of the `shapedType`'s element
158+ // / type if it's also a vector.
159+ // /
160+ // / This is used to determine the number of minor dimensions for identity maps
161+ // / in vector transfers.
162+ // /
163+ // / For example, given a transfer operation involving `shapedType` and
164+ // / `vectorType`:
165+ // /
166+ // / - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167+ // / - shapedType.getElementType() = f32 (rank 0)
168+ // / - vectorType.getRank() = 2
169+ // / - Result = 2 - 0 = 2
170+ // /
171+ // / - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172+ // / - shapedType.getElementType() = vector<20xf32> (rank 1)
173+ // / - vectorType.getRank() = 1
174+ // / - Result = 1 - 1 = 0
175+ static unsigned getRealVectorRank (ShapedType shapedType,
176+ VectorType vectorType) {
177+ unsigned elementVectorRank = 0 ;
157178 VectorType elementVectorType =
158179 llvm::dyn_cast<VectorType>(shapedType.getElementType ());
159180 if (elementVectorType)
160181 elementVectorRank += elementVectorType.getRank ();
182+ return vectorType.getRank () - elementVectorRank;
183+ }
184+
185+ AffineMap mlir::vector::getTransferMinorIdentityMap (ShapedType shapedType,
186+ VectorType vectorType) {
161187 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162188 // TODO: replace once we have 0-d vectors.
163189 if (shapedType.getRank () == 0 &&
@@ -166,7 +192,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
166192 /* numDims=*/ 0 , /* numSymbols=*/ 0 ,
167193 getAffineConstantExpr (0 , shapedType.getContext ()));
168194 return AffineMap::getMinorIdentityMap (
169- shapedType.getRank (), vectorType. getRank () - elementVectorRank ,
195+ shapedType.getRank (), getRealVectorRank (shapedType, vectorType) ,
170196 shapedType.getContext ());
171197}
172198
@@ -4234,6 +4260,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42344260 Attribute permMapAttr = result.attributes .get (permMapAttrName);
42354261 AffineMap permMap;
42364262 if (!permMapAttr) {
4263+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4264+ return parser.emitError (typesLoc,
4265+ " expected a custom permutation_map when "
4266+ " rank(source) != rank(destination)" );
42374267 permMap = getTransferMinorIdentityMap (shapedType, vectorType);
42384268 result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
42394269 } else {
@@ -4649,6 +4679,10 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46494679 auto permMapAttr = result.attributes .get (permMapAttrName);
46504680 AffineMap permMap;
46514681 if (!permMapAttr) {
4682+ if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4683+ return parser.emitError (typesLoc,
4684+ " expected a custom permutation_map when "
4685+ " rank(source) != rank(destination)" );
46524686 permMap = getTransferMinorIdentityMap (shapedType, vectorType);
46534687 result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
46544688 } else {
0 commit comments