@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
445445 }
446446};
447447
448+ // / This pattern converts the CreateMaskOp to work on a
449+ // / linearized vector. The pattern currently
450+ // / supports only 2D masks with a unit outer dimension.
451+ // / Following,
452+ // / vector.create_mask %dims : vector<1x4xi1>
453+ // / is converted to:
454+ // / %out_1d = vector.create_mask %dims : vector<4xi1>
455+ // / %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
456+ struct LinearizeVectorCreateMask final
457+ : OpConversionPattern<vector::CreateMaskOp> {
458+ using OpConversionPattern::OpConversionPattern;
459+
460+ LinearizeVectorCreateMask (const TypeConverter &typeConverter,
461+ MLIRContext *context, PatternBenefit benefit = 1 )
462+ : OpConversionPattern(typeConverter, context, benefit) {}
463+
464+ LogicalResult
465+ matchAndRewrite (vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
466+ ConversionPatternRewriter &rewriter) const override {
467+ auto srcTy = createMaskOp.getType ();
468+ auto srcShape = srcTy.getShape ();
469+ if (srcShape.size () != 2 )
470+ return rewriter.notifyMatchFailure (createMaskOp,
471+ " only 2D mask is supported." );
472+
473+ if (srcShape[0 ] != 1 )
474+ return rewriter.notifyMatchFailure (
475+ createMaskOp, " only unit outer dimension is supported." );
476+
477+ auto dstTy = getTypeConverter ()->convertType (srcTy);
478+ if (!dstTy)
479+ return rewriter.notifyMatchFailure (createMaskOp, " cannot convert type." );
480+
481+ // Compare the first operand with 0. If it's less than or equal to 0,
482+ // create a zero mask, else strip the first operand and create a mask
483+ // using the second operand.
484+ auto firstOperand = adaptor.getOperands ().front ();
485+ auto zero =
486+ rewriter.create <mlir::arith::ConstantIndexOp>(createMaskOp.getLoc (), 0 );
487+ auto isZeroOrNegative = rewriter.create <mlir::arith::CmpIOp>(
488+ createMaskOp.getLoc (), mlir::arith::CmpIPredicate::sle, firstOperand,
489+ zero);
490+ auto isZeroOrNegativeSplat = rewriter.create <mlir::vector::SplatOp>(
491+ createMaskOp.getLoc (), dstTy, isZeroOrNegative);
492+
493+ // Use a select operation to choose between the masks.
494+ auto zeroMask = rewriter.create <mlir::arith::ConstantOp>(
495+ createMaskOp.getLoc (), dstTy, rewriter.getZeroAttr (dstTy));
496+ auto newMask = rewriter.create <mlir::vector::CreateMaskOp>(
497+ createMaskOp.getLoc (), dstTy, adaptor.getOperands ().back ());
498+ auto result = rewriter.create <mlir::arith::SelectOp>(
499+ createMaskOp.getLoc (), isZeroOrNegativeSplat, zeroMask, newMask);
500+
501+ rewriter.replaceOp (createMaskOp, result.getResult ());
502+ return success ();
503+ }
504+ };
505+
448506} // namespace
449507
450508// / Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
530588void mlir::vector::populateVectorLinearizeBasePatterns (
531589 const TypeConverter &typeConverter, const ConversionTarget &target,
532590 RewritePatternSet &patterns) {
533- patterns.add <LinearizeConstantLike, LinearizeVectorizable,
534- LinearizeVectorBitCast, LinearizeVectorSplat>(
535- typeConverter, patterns.getContext ());
591+ patterns
592+ .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
593+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
594+ typeConverter, patterns.getContext ());
536595}
537596
538597void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments