@@ -62,12 +62,10 @@ struct LinearizeConstantLike final
6262 if (op->getNumResults () != 1 )
6363 return rewriter.notifyMatchFailure (loc, " expected 1 result" );
6464
65- const TypeConverter &converter = *getTypeConverter ();
65+ const TypeConverter &typeConverter = *getTypeConverter ();
6666 auto resType =
67- converter.convertType <VectorType>(op->getResult (0 ).getType ());
68-
69- if (!resType)
70- return rewriter.notifyMatchFailure (loc, " can't convert return type" );
67+ typeConverter.convertType <VectorType>(op->getResult (0 ).getType ());
68+ assert (resType && " expected 1-D vector type" );
7169
7270 StringAttr attrName = rewriter.getStringAttr (" value" );
7371 Attribute value = op->getAttr (attrName);
@@ -80,7 +78,7 @@ struct LinearizeConstantLike final
8078 return failure ();
8179
8280 FailureOr<Operation *> convertResult =
83- convertOpResultTypes (op, /* operands=*/ {}, converter , rewriter);
81+ convertOpResultTypes (op, /* operands=*/ {}, typeConverter , rewriter);
8482 if (failed (convertResult))
8583 return failure ();
8684
@@ -244,14 +242,6 @@ struct LinearizeVectorShuffle final
244242 VectorType dstType =
245243 getTypeConverter ()->convertType <VectorType>(shuffleOp.getType ());
246244 assert (dstType && " vector type destination expected." );
247- // The assert is used because vector.shuffle does not support scalable
248- // vectors.
249- bool scalable = shuffleOp.getV1VectorType ().isScalable () ||
250- shuffleOp.getV2VectorType ().isScalable () ||
251- dstType.isScalable ();
252- if (scalable)
253- return rewriter.notifyMatchFailure (shuffleOp,
254- " scalable vectors are not supported." );
255245
256246 Value vec1 = adaptor.getV1 ();
257247 Value vec2 = adaptor.getV2 ();
@@ -270,7 +260,7 @@ struct LinearizeVectorShuffle final
270260 }
271261
272262 // For each value in the mask, we generate the indices of the source vectors
273- // that needs to be shuffled to the destination vector. If shuffleSliceLen >
263+ // that need to be shuffled to the destination vector. If shuffleSliceLen >
274264 // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
275265 // elements) instead of scalars.
276266 ArrayRef<int64_t > mask = shuffleOp.getMask ();
@@ -309,14 +299,7 @@ struct LinearizeVectorExtract final
309299 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
310300 ConversionPatternRewriter &rewriter) const override {
311301 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
312- if (!dstTy)
313- return rewriter.notifyMatchFailure (extractOp,
314- " expected n-D vector type." );
315-
316- if (extractOp.getVector ().getType ().isScalable () ||
317- cast<VectorType>(dstTy).isScalable ())
318- return rewriter.notifyMatchFailure (extractOp,
319- " scalable vectors are not supported." );
302+ assert (dstTy && " expected 1-D vector type" );
320303
321304 // Dynamic position is not supported.
322305 if (extractOp.hasDynamicPosition ())
@@ -367,9 +350,6 @@ struct LinearizeVectorInsert final
367350 VectorType dstTy = getTypeConverter ()->convertType <VectorType>(
368351 insertOp.getDestVectorType ());
369352 assert (dstTy && " vector type destination expected." );
370- if (insertOp.getDestVectorType ().isScalable () || dstTy.isScalable ())
371- return rewriter.notifyMatchFailure (insertOp,
372- " scalable vectors are not supported." );
373353
374354 // dynamic position is not supported
375355 if (insertOp.hasDynamicPosition ())
@@ -436,11 +416,8 @@ struct LinearizeVectorBitCast final
436416 LogicalResult
437417 matchAndRewrite (vector::BitCastOp castOp, OpAdaptor adaptor,
438418 ConversionPatternRewriter &rewriter) const override {
439- Location loc = castOp.getLoc ();
440419 auto resType = getTypeConverter ()->convertType (castOp.getType ());
441- if (!resType)
442- return rewriter.notifyMatchFailure (loc, " can't convert return type." );
443-
420+ assert (resType && " expected 1-D vector type" );
444421 rewriter.replaceOpWithNewOp <vector::BitCastOp>(castOp, resType,
445422 adaptor.getSource ());
446423 return mlir::success ();
@@ -449,56 +426,15 @@ struct LinearizeVectorBitCast final
449426
450427} // namespace
451428
452- // / If `type` is VectorType with trailing dimension of (bit) size greater than
453- // / or equal to `targetBitWidth`, its defining op is considered legal.
454- static bool legalBecauseOfBitwidth (Type type, unsigned targetBitWidth) {
455-
456- VectorType vecType = dyn_cast<VectorType>(type);
457-
458- if (!vecType)
459- return true ;
460-
461- // The width of the type 'index' is unbounded (and therefore potentially above
462- // the target width).
463- if (vecType.getElementType ().isIndex ())
464- return true ;
465-
466- unsigned finalDimSize =
467- vecType.getRank () == 0 ? 0 : vecType.getShape ().back ();
468-
469- unsigned trailingVecDimBitWidth =
470- finalDimSize * vecType.getElementTypeBitWidth ();
471-
472- return trailingVecDimBitWidth >= targetBitWidth;
473- }
474-
475- static SmallVector<std::pair<Type, unsigned >>
476- getChecksForBitwidth (Operation *op, unsigned targetBitWidth) {
477-
478- if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
479- auto w = targetBitWidth < std::numeric_limits<unsigned >::max ()
480- ? targetBitWidth + 1
481- : targetBitWidth;
482- return {{insertOp.getValueToStoreType (), w}};
483- }
484- auto resultTypes = op->getResultTypes ();
485- SmallVector<std::pair<Type, unsigned >> resultsWithBitWidth;
486- resultsWithBitWidth.reserve (resultTypes.size ());
487- for (Type type : resultTypes) {
488- resultsWithBitWidth.push_back ({type, targetBitWidth});
489- }
490- return resultsWithBitWidth;
491- }
492-
493429// / Return true if the operation `op` does not support scalable vectors and
494- // / has at least 1 scalable vector result.
495- static bool legalBecauseScalable (Operation *op) {
496-
497- bool scalableSupported = op-> hasTrait <OpTrait::ConstantLike>() ||
498- op-> hasTrait <OpTrait::Vectorizable>() ||
499- isa<vector::BitCastOp>(op);
500-
501- if (scalableSupported )
430+ // / has at least 1 scalable vector result. These ops should all eventually
431+ // / support scalable vectors, and this function should be removed.
432+ static bool isNotLinearizableBecauseScalable (Operation *op) {
433+
434+ bool unsupported =
435+ isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
436+ op);
437+ if (!unsupported )
502438 return false ;
503439
504440 // Check if any of the results is a scalable vector type.
@@ -512,73 +448,74 @@ static bool legalBecauseScalable(Operation *op) {
512448 return containsScalableResult;
513449}
514450
515- static bool dynamicallyLegal (Operation *op, unsigned targetBitWidth ) {
451+ static bool isNotLinearizable (Operation *op) {
516452
517453 // Only ops that are in the vector dialect, are ConstantLike, or
518- // are Vectorizable might be linearized currently, so legalize the others.
519- bool opIsVectorDialect = op->getDialect ()->getNamespace () ==
520- vector::VectorDialect::getDialectNamespace ();
521- if (!opIsVectorDialect && !op->hasTrait <OpTrait::ConstantLike>() &&
522- !op->hasTrait <OpTrait::Vectorizable>())
454+ // are Vectorizable might be linearized currently.
455+ StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace ();
456+ StringRef opDialect = op->getDialect ()->getNamespace ();
457+ bool unsupported = (opDialect != vectorDialect) &&
458+ !op->hasTrait <OpTrait::ConstantLike>() &&
459+ !op->hasTrait <OpTrait::Vectorizable>();
460+ if (unsupported)
523461 return true ;
524462
525- // Some ops will not be linearized if they have scalable vector results .
526- if (legalBecauseScalable (op))
463+ // Some ops currently don't support scalable vectors .
464+ if (isNotLinearizableBecauseScalable (op))
527465 return true ;
528466
529- // Check on bitwidths.
530- auto typesToCheck = getChecksForBitwidth (op, targetBitWidth);
531- return std::any_of (typesToCheck.begin (), typesToCheck.end (),
532- [&](std::pair<Type, unsigned > typeWidth) {
533- return legalBecauseOfBitwidth (typeWidth.first ,
534- typeWidth.second );
535- });
467+ return false ;
536468}
537469
538- void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter (
539- TypeConverter &typeConverter, ConversionTarget &target,
540- unsigned targetBitWidth) {
470+ void mlir::vector::populateForVectorLinearize (TypeConverter &typeConverter,
471+ ConversionTarget &target) {
541472
542- typeConverter.addConversion ([](VectorType type) -> std::optional<Type> {
543- if (!isLinearizableVector (type))
473+ auto convertType = [](Type type) -> std::optional<Type> {
474+ VectorType vectorType = dyn_cast<VectorType>(type);
475+ if (!vectorType || !isLinearizableVector (vectorType))
544476 return type;
545477
546- return VectorType::get (type.getNumElements (), type.getElementType (),
547- type.isScalable ());
548- });
478+ VectorType linearizedType =
479+ VectorType::get (vectorType.getNumElements (),
480+ vectorType.getElementType (), vectorType.isScalable ());
481+ return linearizedType;
482+ };
483+ typeConverter.addConversion (convertType);
549484
550485 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
551486 Location loc) -> Value {
552- if (inputs.size () != 1 || !isa<VectorType>(inputs.front ().getType ()) ||
553- !isa<VectorType>(type))
487+ if (inputs.size () != 1 )
554488 return nullptr ;
555- return builder.create <vector::ShapeCastOp>(loc, type, inputs.front ());
556- };
557489
490+ Value value = inputs.front ();
491+ if (!isa<VectorType>(type) || !isa<VectorType>(value.getType ()))
492+ return nullptr ;
493+
494+ return builder.create <vector::ShapeCastOp>(loc, type, value);
495+ };
558496 typeConverter.addSourceMaterialization (materializeCast);
559497 typeConverter.addTargetMaterialization (materializeCast);
560498
561499 target.markUnknownOpDynamicallyLegal (
562500 [=](Operation *op) -> std::optional<bool > {
563- bool isDynamicallyLegal = dynamicallyLegal (op, targetBitWidth);
564- if (isDynamicallyLegal)
501+ if (isNotLinearizable (op))
565502 return true ;
566-
567- bool shapeUnchanged = typeConverter. isLegal (op);
568- return shapeUnchanged ;
503+ // This will return true if, for all operand and result types `t`,
504+ // convertType(t) = t. This is true if there are no rank>=2 vectors.
505+ return typeConverter. isLegal (op) ;
569506 });
570507}
571508
572509void mlir::vector::populateVectorLinearizeBasePatterns (
573- const TypeConverter &typeConverter, RewritePatternSet &patterns ,
574- const ConversionTarget &target ) {
510+ const TypeConverter &typeConverter, const ConversionTarget &target ,
511+ RewritePatternSet &patterns ) {
575512 patterns.add <LinearizeConstantLike, LinearizeVectorizable,
576513 LinearizeVectorBitCast>(typeConverter, patterns.getContext ());
577514}
578515
579516void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
580- const TypeConverter &typeConverter, RewritePatternSet &patterns ,
581- const ConversionTarget &target ) {
517+ const TypeConverter &typeConverter, const ConversionTarget &target ,
518+ RewritePatternSet &patterns ) {
582519 patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
583520 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
584521 typeConverter, patterns.getContext ());
0 commit comments