1010//
1111// ===----------------------------------------------------------------------===//
1212
13+ #include " mlir/Dialect/Arith/IR/Arith.h"
1314#include " mlir/Dialect/UB/IR/UBOps.h"
1415#include " mlir/Dialect/Vector/IR/VectorOps.h"
1516#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2021#include " mlir/IR/TypeUtilities.h"
2122#include " mlir/Transforms/DialectConversion.h"
2223#include " llvm/ADT/ArrayRef.h"
24+ #include < algorithm>
2325#include < cstdint>
2426#include < numeric>
2527#include < optional>
@@ -469,6 +471,11 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
469471 return containsScalableResult;
470472}
471473
474+ static bool isLinearizableCreateMaskOp (vector::CreateMaskOp createMaskOp) {
475+ auto shape = createMaskOp.getType ().getShape ();
476+ return llvm::count_if (shape, [](int64_t dim) { return dim > 1 ; }) <= 1 ;
477+ }
478+
472479static bool isNotLinearizable (Operation *op) {
473480
474481 // Only ops that are in the vector dialect, are ConstantLike, or
@@ -485,6 +492,12 @@ static bool isNotLinearizable(Operation *op) {
485492 if (isNotLinearizableBecauseScalable (op))
486493 return true ;
487494
495+ if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
496+ if (!isLinearizableCreateMaskOp (createMaskOp)) {
497+ return true ;
498+ }
499+ }
500+
488501 return false ;
489502}
490503
@@ -527,12 +540,54 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
527540 });
528541}
529542
543+ // / Linearize vector.create_mask with at most 1 non-unit dimension. Example:
544+ // /
545+ // / ```
546+ // / %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
547+ // / ```
548+ // /
549+ // / becomes
550+ // /
551+ // / ```
552+ // / %0 = arith.muli %arg0, %arg1 : index
553+ // / %1 = arith.muli %0, %arg2 : index
554+ // / %2 = vector.create_mask %1: vector<16xi1>
555+ // / %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
556+ // / ```
557+ struct LinearizeVectorCreateMask final
558+ : OpConversionPattern<vector::CreateMaskOp> {
559+ using OpConversionPattern::OpConversionPattern;
560+
561+ LinearizeVectorCreateMask (const TypeConverter &typeConverter,
562+ MLIRContext *context, PatternBenefit benefit = 1 )
563+ : OpConversionPattern(typeConverter, context, benefit) {}
564+
565+ LogicalResult
566+ matchAndRewrite (vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
567+ ConversionPatternRewriter &rewriter) const override {
568+
569+ VectorType maskType = createMaskOp.getType ();
570+ assert (isLinearizableCreateMaskOp (createMaskOp));
571+
572+ Value product = adaptor.getOperands ().front ();
573+ for (unsigned i = 1 ; i < maskType.getRank (); ++i) {
574+ product = rewriter.create <mlir::arith::MulIOp>(
575+ createMaskOp.getLoc (), product, adaptor.getOperands ()[i]);
576+ }
577+ Type flatMaskType = getTypeConverter ()->convertType (maskType);
578+ auto newMask = rewriter.create <mlir::vector::CreateMaskOp>(
579+ createMaskOp.getLoc (), flatMaskType, product);
580+ rewriter.replaceOp (createMaskOp, newMask);
581+ return success ();
582+ }
583+ };
584+
530585void mlir::vector::populateVectorLinearizeBasePatterns (
531586 const TypeConverter &typeConverter, const ConversionTarget &target,
532587 RewritePatternSet &patterns) {
533588 patterns.add <LinearizeConstantLike, LinearizeVectorizable,
534- LinearizeVectorBitCast, LinearizeVectorSplat>(
535- typeConverter, patterns.getContext ());
589+ LinearizeVectorCreateMask, LinearizeVectorBitCast,
590+ LinearizeVectorSplat>( typeConverter, patterns.getContext ());
536591}
537592
538593void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments