@@ -566,6 +566,64 @@ struct LinearizeVectorSplat final
566566 }
567567};
568568
569+ // / This pattern converts the CreateMaskOp to work on a linearized vector.
570+ // / It currently supports only 2D masks with a unit outer dimension.
571+ // / Following,
572+ // / vector.create_mask %arg0, %arg1 : vector<1x4xi1>
573+ // / is converted to:
574+ // / %zero = arith.constant 0 : index
575+ // / %cmpi = arith.cmpi sgt, %arg0, %zero : index
576+ // / %index = arith.index_cast %cmpi : i1 to index
577+ // / %mul = arith.andi %index, %arg1 : index
578+ // / %mask = vector.create_mask %mul : vector<4xi1>
579+ // / %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
580+ struct LinearizeVectorCreateMask final
581+ : OpConversionPattern<vector::CreateMaskOp> {
582+ using OpConversionPattern::OpConversionPattern;
583+
584+ LinearizeVectorCreateMask (const TypeConverter &typeConverter,
585+ MLIRContext *context, PatternBenefit benefit = 1 )
586+ : OpConversionPattern(typeConverter, context, benefit) {}
587+
588+ LogicalResult
589+ matchAndRewrite (vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
590+ ConversionPatternRewriter &rewriter) const override {
591+ Location loc = createMaskOp.getLoc ();
592+ VectorType srcTy = createMaskOp.getType ();
593+ auto srcShape = srcTy.getShape ();
594+ if (srcShape.size () != 2 )
595+ return rewriter.notifyMatchFailure (createMaskOp,
596+ " only 2D mask is supported." );
597+
598+ if (srcShape[0 ] != 1 )
599+ return rewriter.notifyMatchFailure (
600+ createMaskOp, " only unit outer dimension is supported." );
601+
602+ auto dstTy = getTypeConverter ()->convertType (srcTy);
603+ if (!dstTy)
604+ return rewriter.notifyMatchFailure (createMaskOp, " cannot convert type." );
605+
606+ // Compare the first operand with 0. If it is greater than 0, the
607+ // corresponding mask element is set to true, otherwise false.
608+ // The result of the comparison is then multiplied with
609+ // the second operand of create_mask to get the 1D mask.
610+ auto firstOperand = adaptor.getOperands ().front ();
611+ auto zero = rewriter.create <mlir::arith::ConstantIndexOp>(loc, 0 );
612+ auto isNonZero = rewriter.createOrFold <mlir::arith::CmpIOp>(
613+ loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
614+ auto isNonZeroIndex = rewriter.createOrFold <mlir::arith::IndexCastOp>(
615+ loc, rewriter.getIndexType (), isNonZero);
616+ auto secondOperand = adaptor.getOperands ().back ();
617+ auto maskSize = rewriter.createOrFold <mlir::arith::AndIOp>(
618+ loc, rewriter.getIndexType (), isNonZeroIndex, secondOperand);
619+
620+ auto newMask =
621+ rewriter.create <mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
622+ rewriter.replaceOp (createMaskOp, newMask);
623+ return success ();
624+ }
625+ };
626+
569627} // namespace
570628
571629// / Return true if the operation `op` does not support scalable vectors and
@@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
651709void mlir::vector::populateVectorLinearizeBasePatterns (
652710 const TypeConverter &typeConverter, const ConversionTarget &target,
653711 RewritePatternSet &patterns) {
654- patterns.add <LinearizeConstantLike, LinearizeVectorizable,
655- LinearizeVectorBitCast, LinearizeVectorSplat>(
656- typeConverter, patterns.getContext ());
712+ patterns
713+ .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
714+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
715+ typeConverter, patterns.getContext ());
657716}
658717
659718void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments