@@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151151 return false ;
152152}
153153
154- // / Returns the number of dimensions of the `shapedType` that participate in the
155- // / vector transfer, effectively the rank of the vector dimensions within the
156- // / `shapedType`. This is calculated by taking the rank of the `vectorType`
157- // / being transferred and subtracting the rank of the `shapedType`'s element
158- // / type if it's also a vector.
154+ // / Returns the effective rank of the vector to read/write for Xfer Ops
159155// /
160- // / This is used to determine the number of minor dimensions for identity maps
161- // / in vector transfers.
156+ // / When the element type of the shaped type is _a scalar_, this will simply
157+ // / return the rank of the vector ( the result for xfer_read or the value to
158+ // / store for xfer_write).
162159// /
163- // / For example, given a transfer operation involving `shapedType` and
164- // / `vectorType`:
160+ // / When the element type of the base shaped type is _a vector_, returns the
161+ // / difference between the original vector type and the element type of the
162+ // / shaped type.
165163// /
164+ // / EXAMPLE 1 (element type is _a scalar_):
166165// / - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167166// / - shapedType.getElementType() = f32 (rank 0)
168167// / - vectorType.getRank() = 2
169168// / - Result = 2 - 0 = 2
170169// /
170+ // / EXAMPLE 2 (element type is _a vector_):
171171// / - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172172// / - shapedType.getElementType() = vector<20xf32> (rank 1)
173173// / - vectorType.getRank() = 1
174174// / - Result = 1 - 1 = 0
175- static unsigned getRealVectorRank (ShapedType shapedType,
176- VectorType vectorType) {
175+ // /
176+ // / This is used to determine the number of minor dimensions for identity maps
177+ // / in vector transfer Ops.
178+ static unsigned getEffectiveVectorRankForXferOp (ShapedType shapedType,
179+ VectorType vectorType) {
177180 unsigned elementVectorRank = 0 ;
178181 VectorType elementVectorType =
179182 llvm::dyn_cast<VectorType>(shapedType.getElementType ());
@@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
192195 /* numDims=*/ 0 , /* numSymbols=*/ 0 ,
193196 getAffineConstantExpr (0 , shapedType.getContext ()));
194197 return AffineMap::getMinorIdentityMap (
195- shapedType.getRank (), getRealVectorRank (shapedType, vectorType),
198+ shapedType.getRank (),
199+ getEffectiveVectorRankForXferOp (shapedType, vectorType),
196200 shapedType.getContext ());
197201}
198202
@@ -4261,7 +4265,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42614265 Attribute permMapAttr = result.attributes .get (permMapAttrName);
42624266 AffineMap permMap;
42634267 if (!permMapAttr) {
4264- if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4268+ if (shapedType.getRank () <
4269+ getEffectiveVectorRankForXferOp (shapedType, vectorType))
42654270 return parser.emitError (typesLoc,
42664271 " expected a custom permutation_map when "
42674272 " rank(source) != rank(destination)" );
@@ -4680,7 +4685,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46804685 auto permMapAttr = result.attributes .get (permMapAttrName);
46814686 AffineMap permMap;
46824687 if (!permMapAttr) {
4683- if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4688+ if (shapedType.getRank () <
4689+ getEffectiveVectorRankForXferOp (shapedType, vectorType))
46844690 return parser.emitError (typesLoc,
46854691 " expected a custom permutation_map when "
46864692 " rank(source) != rank(destination)" );
0 commit comments