2525
2626using namespace mlir ;
2727
28- static bool isLessThanTargetBitWidth (Operation *op, unsigned targetBitWidth) {
28+ static bool isLessThanTargetBitWidth (Operation *op, unsigned indexBitWidth,
29+ unsigned targetBitWidth) {
2930 auto resultTypes = op->getResultTypes ();
3031 for (auto resType : resultTypes) {
3132 VectorType vecType = dyn_cast<VectorType>(resType);
32- // Reject index since getElementTypeBitWidth will abort for Index types.
33- if (!vecType || vecType.getElementType ().isIndex ())
33+ if (!vecType)
34+ return false ;
35+ bool isIndexTy = vecType.getElementType ().isIndex ();
36+ // Reject index if `indexBitWidth` is not supplied.
37+ if (isIndexTy && indexBitWidth == 0 )
3438 return false ;
3539 // There are no dimension to fold if it is a 0-D vector.
3640 if (vecType.getRank () == 0 )
3741 return false ;
3842 unsigned trailingVecDimBitWidth =
39- vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
43+ vecType.getShape ().back () *
44+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth ());
4045 if (trailingVecDimBitWidth >= targetBitWidth)
4146 return false ;
4247 }
4348 return true ;
4449}
4550
46- static bool isLessThanOrEqualTargetBitWidth (Type t, unsigned targetBitWidth) {
51+ static bool isLessThanOrEqualTargetBitWidth (Type t, unsigned indexBitWidth,
52+ unsigned targetBitWidth) {
4753 VectorType vecType = dyn_cast<VectorType>(t);
48- // Reject index since getElementTypeBitWidth will abort for Index types.
49- if (!vecType || vecType.getElementType ().isIndex ())
54+ if (!vecType)
55+ return false ;
56+ bool isIndexTy = vecType.getElementType ().isIndex ();
57+ // Reject index if `indexBitWidth` is not supplied.
58+ if (isIndexTy && indexBitWidth == 0 )
5059 return false ;
5160 // There are no dimension to fold if it is a 0-D vector.
5261 if (vecType.getRank () == 0 )
5362 return false ;
5463 unsigned trailingVecDimBitWidth =
55- vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
64+ vecType.getShape ().back () *
65+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth ());
5666 return trailingVecDimBitWidth <= targetBitWidth;
5767}
5868
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
6171 using OpConversionPattern::OpConversionPattern;
6272 LinearizeConstant (
6373 const TypeConverter &typeConverter, MLIRContext *context,
74+ unsigned indexBitWidth = 0 ,
6475 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
6576 PatternBenefit benefit = 1 )
6677 : OpConversionPattern(typeConverter, context, benefit),
67- targetVectorBitWidth (targetVectBitWidth) {}
78+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
79+ }
6880 LogicalResult
6981 matchAndRewrite (arith::ConstantOp constOp, OpAdaptor adaptor,
7082 ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
7991
8092 if (!resType)
8193 return rewriter.notifyMatchFailure (loc, " can't convert return type" );
82- if (!isLessThanTargetBitWidth (constOp, targetVectorBitWidth))
94+ if (!isLessThanTargetBitWidth (constOp, indexBitWidth, targetVectorBitWidth))
8395 return rewriter.notifyMatchFailure (
8496 loc, " Can't flatten since targetBitWidth <= OpSize" );
8597 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue ());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
93105 }
94106
95107private:
108+ unsigned indexBitWidth;
96109 unsigned targetVectorBitWidth;
97110};
98111
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
103116public:
104117 LinearizeVectorizable (
105118 const TypeConverter &typeConverter, MLIRContext *context,
119+ unsigned indexBitWidth = 0 ,
106120 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
107121 PatternBenefit benefit = 1 )
108122 : OpTraitConversionPattern(typeConverter, context, benefit),
109- targetVectorBitWidth (targetVectBitWidth) {}
123+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
124+ }
110125 LogicalResult
111126 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
112127 ConversionPatternRewriter &rewriter) const override {
113- if (!isLessThanTargetBitWidth (op, targetVectorBitWidth))
128+ if (!isLessThanTargetBitWidth (op, indexBitWidth, targetVectorBitWidth))
114129 return rewriter.notifyMatchFailure (
115130 op->getLoc (), " Can't flatten since targetBitWidth <= OpSize" );
116131 FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
123138 }
124139
125140private:
141+ unsigned indexBitWidth;
126142 unsigned targetVectorBitWidth;
127143};
128144
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
142158 using OpConversionPattern::OpConversionPattern;
143159 LinearizeVectorExtractStridedSlice (
144160 const TypeConverter &typeConverter, MLIRContext *context,
161+ unsigned indexBitWidth = 0 ,
145162 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
146163 PatternBenefit benefit = 1 )
147164 : OpConversionPattern(typeConverter, context, benefit),
148- targetVectorBitWidth (targetVectBitWidth) {}
165+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
166+ }
149167
150168 LogicalResult
151169 matchAndRewrite (vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
156174 if (extractOp.getVector ().getType ().isScalable () || dstType.isScalable ())
157175 return rewriter.notifyMatchFailure (extractOp,
158176 " scalable vectors are not supported." );
159- if (!isLessThanTargetBitWidth (extractOp, targetVectorBitWidth))
177+ if (!isLessThanTargetBitWidth (extractOp, indexBitWidth,
178+ targetVectorBitWidth))
160179 return rewriter.notifyMatchFailure (
161180 extractOp, " Can't flatten since targetBitWidth <= OpSize" );
162181
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
237256 }
238257
239258private:
259+ unsigned indexBitWidth;
240260 unsigned targetVectorBitWidth;
241261};
242262
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
256276 using OpConversionPattern::OpConversionPattern;
257277 LinearizeVectorShuffle (
258278 const TypeConverter &typeConverter, MLIRContext *context,
279+ unsigned indexBitWidth = 0 ,
259280 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
260281 PatternBenefit benefit = 1 )
261282 : OpConversionPattern(typeConverter, context, benefit),
262- targetVectorBitWidth (targetVectBitWidth) {}
283+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
284+ }
263285
264286 LogicalResult
265287 matchAndRewrite (vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
273295 shuffleOp.getV2VectorType ().isScalable () ||
274296 dstType.isScalable ()) &&
275297 " scalable vectors are not supported." );
276- if (!isLessThanTargetBitWidth (shuffleOp, targetVectorBitWidth))
298+ if (!isLessThanTargetBitWidth (shuffleOp, indexBitWidth,
299+ targetVectorBitWidth))
277300 return rewriter.notifyMatchFailure (
278301 shuffleOp, " Can't flatten since targetBitWidth <= OpSize" );
279302
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
312335 }
313336
314337private:
338+ unsigned indexBitWidth;
315339 unsigned targetVectorBitWidth;
316340};
317341
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
329353 using OpConversionPattern::OpConversionPattern;
330354 LinearizeVectorExtract (
331355 const TypeConverter &typeConverter, MLIRContext *context,
356+ unsigned indexBitWidth = 0 ,
332357 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
333358 PatternBenefit benefit = 1 )
334359 : OpConversionPattern(typeConverter, context, benefit),
335- targetVectorBitWidth (targetVectBitWidth) {}
360+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
361+ }
336362 LogicalResult
337363 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
338364 ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
345371 cast<VectorType>(dstTy).isScalable ())
346372 return rewriter.notifyMatchFailure (extractOp,
347373 " scalable vectors are not supported." );
348- if (!isLessThanTargetBitWidth (extractOp, targetVectorBitWidth))
374+ if (!isLessThanTargetBitWidth (extractOp, indexBitWidth,
375+ targetVectorBitWidth))
349376 return rewriter.notifyMatchFailure (
350377 extractOp, " Can't flatten since targetBitWidth <= OpSize" );
351378
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
374401 }
375402
376403private:
404+ unsigned indexBitWidth;
377405 unsigned targetVectorBitWidth;
378406};
379407
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
392420 using OpConversionPattern::OpConversionPattern;
393421 LinearizeVectorInsert (
394422 const TypeConverter &typeConverter, MLIRContext *context,
423+ unsigned indexBitWidth = 0 ,
395424 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
396425 PatternBenefit benefit = 1 )
397426 : OpConversionPattern(typeConverter, context, benefit),
398- targetVectorBitWidth (targetVectBitWidth) {}
427+ indexBitWidth (indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
428+ }
399429 LogicalResult
400430 matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
401431 ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
407437 " scalable vectors are not supported." );
408438
409439 if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
410- targetVectorBitWidth))
440+ indexBitWidth, targetVectorBitWidth))
411441 return rewriter.notifyMatchFailure (
412442 insertOp, " Can't flatten since targetBitWidth < OpSize" );
413443
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
457487 }
458488
459489private:
490+ unsigned indexBitWidth;
460491 unsigned targetVectorBitWidth;
461492};
462493} // namespace
463494
464495void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
465496 TypeConverter &typeConverter, RewritePatternSet &patterns,
466- ConversionTarget &target, unsigned targetBitWidth) {
497+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
467498
468499 typeConverter.addConversion ([](VectorType type) -> std::optional<Type> {
469500 if (!isLinearizableVector (type))
@@ -488,29 +519,31 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
488519 [=](Operation *op) -> std::optional<bool > {
489520 if ((isa<arith::ConstantOp>(op) ||
490521 op->hasTrait <OpTrait::Vectorizable>())) {
491- return (isLessThanTargetBitWidth (op, targetBitWidth)
522+ return (isLessThanTargetBitWidth (op, indexBitWidth, targetBitWidth)
492523 ? typeConverter.isLegal (op)
493524 : true );
494525 }
495526 return std::nullopt ;
496527 });
497528
498529 patterns.add <LinearizeConstant, LinearizeVectorizable>(
499- typeConverter, patterns.getContext (), targetBitWidth);
530+ typeConverter, patterns.getContext (), indexBitWidth, targetBitWidth);
500531}
501532
502533void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
503534 const TypeConverter &typeConverter, RewritePatternSet &patterns,
504- ConversionTarget &target, unsigned int targetBitWidth) {
535+ ConversionTarget &target, unsigned indexBitWidth,
536+ unsigned int targetBitWidth) {
505537 target.addDynamicallyLegalOp <vector::ShuffleOp>(
506538 [=](vector::ShuffleOp shuffleOp) -> bool {
507- return isLessThanTargetBitWidth (shuffleOp, targetBitWidth)
539+ return isLessThanTargetBitWidth (shuffleOp, indexBitWidth,
540+ targetBitWidth)
508541 ? (typeConverter.isLegal (shuffleOp) &&
509542 cast<mlir::VectorType>(shuffleOp.getResult ().getType ())
510543 .getRank () == 1 )
511544 : true ;
512545 });
513546 patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
514547 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
515- typeConverter, patterns.getContext (), targetBitWidth);
548+ typeConverter, patterns.getContext (), indexBitWidth, targetBitWidth);
516549}
0 commit comments