2121#include " mlir/IR/TypeUtilities.h"
2222#include " mlir/Transforms/DialectConversion.h"
2323#include " llvm/ADT/ArrayRef.h"
24+ #include " llvm/ADT/STLExtras.h"
2425#include < algorithm>
2526#include < cstdint>
2627#include < numeric>
@@ -471,9 +472,12 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
471472 return containsScalableResult;
472473}
473474
474- static bool isLinearizableCreateMaskOp (vector::CreateMaskOp createMaskOp) {
475- auto shape = createMaskOp.getType ().getShape ();
476- return llvm::count_if (shape, [](int64_t dim) { return dim > 1 ; }) <= 1 ;
// Replacement predicate (the `+` lines below): returns true when the shape of
// the create_mask result type has no more than one dimension whose size is
// greater than 1. The removed version above (`-` lines) evaluated the same
// count with `<= 1`; the new one computes `count > 1` and returns its negation.
475+ static bool
476+ isCreateMaskWithAtMostOneNonUnit (vector::CreateMaskOp createMaskOp) {
477+ ArrayRef<int64_t > shape = createMaskOp.getType ().getShape ();
478+ bool multipleNonUnitDim =
479+ llvm::count_if (shape, [](int64_t dim) { return dim > 1 ; }) > 1 ;
480+ return !multipleNonUnitDim;
477481}
478482
479483static bool isNotLinearizable (Operation *op) {
@@ -493,7 +497,7 @@ static bool isNotLinearizable(Operation *op) {
493497 return true ;
494498
495499 if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
496- if (!isLinearizableCreateMaskOp (createMaskOp)) {
500+ if (!isCreateMaskWithAtMostOneNonUnit (createMaskOp)) {
497501 return true ;
498502 }
499503 }
@@ -540,7 +544,8 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
540544 });
541545}
542546
543- // / Linearize vector.create_mask with at most 1 non-unit dimension. Example:
547+ // / Linearize a vector.create_mask that has at most 1 non-unit dimension.
548+ // / Example:
544549// /
545550// / ```
546551// / %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
@@ -549,11 +554,30 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
549554// / becomes
550555// /
551556// / ```
552- // / %0 = arith.muli %arg0, %arg1 : index
553- // / %1 = arith.muli %0, %arg2 : index
554- // / %2 = vector.create_mask %1: vector<16xi1>
557+ // / [...]
558+ // / %2 = vector.create_mask %prod: vector<16xi1>
555559// / %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
556560// / ```
561+ // /
562+ // / where %prod above is the product of the (clamped) dimension-wise masking
563+ // / ranges %arg0, %arg1, and %arg2.
564+ // /
565+ // / This is equivalent to choosing the rank-1 masking range as:
566+ // / 1) %arg1 if %arg0 and %arg2 are strictly positive
567+ // / 2) 0 if either %arg0 or %arg2 is 0 or negative.
568+ // /
569+ // / Specifically, %prod is obtained as
570+ // /
571+ // / ```
572+ // / %true = arith.constant true
573+ // / %zero = arith.constant 0 : index
574+ // / %0 = arith.cmpi sgt, %arg0, %zero : index
575+ // / %1 = arith.muli %true, %0 : i1
576+ // / %2 = arith.cmpi sgt, %arg2, %zero : index
577+ // / %3 = arith.muli %1, %2 : i1
578+ // / %4 = arith.index_cast %3 : i1 to index
579+ // / %prod = arith.muli %4, %arg1 : index
580+ // / ```
557581struct LinearizeVectorCreateMask final
558582 : OpConversionPattern<vector::CreateMaskOp> {
559583 using OpConversionPattern::OpConversionPattern;
@@ -563,21 +587,44 @@ struct LinearizeVectorCreateMask final
563587 : OpConversionPattern(typeConverter, context, benefit) {}
564588
// Replaces the matched create_mask with a new vector.create_mask whose result
// type is whatever the type converter returns for the original mask type and
// whose only operand is a single computed value. In this diff view the `-`
// lines are the removed implementation (a plain arith.muli chain over all
// operands) and the `+` lines are its replacement (`prod`, built below).
// Always returns success(); the precondition is only checked by an assert.
565589 LogicalResult
566- matchAndRewrite (vector::CreateMaskOp createMaskOp , OpAdaptor adaptor,
590+ matchAndRewrite (vector::CreateMaskOp maskOp , OpAdaptor adaptor,
567591 ConversionPatternRewriter &rewriter) const override {
568592
569- VectorType maskType = createMaskOp.getType ();
570- assert (isLinearizableCreateMaskOp (createMaskOp));
593+ VectorType type = maskOp.getType ();
594+ assert (isCreateMaskWithAtMostOneNonUnit (maskOp) &&
595+ " expected linearizable create_mask" );
596+
597+ Location loc = maskOp.getLoc ();
598+
599+ // First, get the product of (clamped) mask sizes in the unit-dimensions.
600+ Value prod = rewriter.create <arith::ConstantIntOp>(loc, 1 , 1 );
601+ Value zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
602+ int nonUnitDim = -1 ;
// For every dimension of size <= 1 the operand is reduced to the result of a
// signed `operand > zero` comparison and multiplied into `prod`. A dimension
// of size > 1 is not multiplied in here; only its index is remembered, and the
// assert enforces that this happens at most once.
603+ for (unsigned i = 0 ; i < type.getRank (); ++i) {
604+ auto v = adaptor.getOperands ()[i];
605+ auto dimSize = type.getDimSize (i);
606+ if (dimSize <= 1 ) {
607+ Value nxt = rewriter.create <arith::CmpIOp>(
608+ loc, arith::CmpIPredicate::sgt, v, zero);
609+ prod = rewriter.create <arith::MulIOp>(loc, prod, nxt);
610+ } else {
611+ assert (nonUnitDim == -1 && " at most 1 non-unit expected" );
612+ nonUnitDim = i;
613+ }
614+ }
// NOTE(review): per the `%true` / `i1` sequence in the struct's doc comment,
// `prod` is a 1-bit value at this point. arith.index_cast is documented as
// sign-extending when casting to a wider type, which would presumably turn a
// `true` bit into -1 (not 1) and negate the product below -- confirm whether
// arith.index_castui (zero-extension) is what is intended here. TODO verify.
615+ prod =
616+ rewriter.create <arith::IndexCastOp>(loc, rewriter.getIndexType (), prod);
571617
572- Value product = adaptor. getOperands (). front ();
573- for ( unsigned i = 1 ; i < maskType. getRank (); ++i ) {
574- product = rewriter. create <mlir::arith::MulIOp>(
575- createMaskOp. getLoc (), product, adaptor. getOperands ()[i] );
618+ // Finally, multiply by the size in the dimension that is not unit.
// When no dimension has size > 1 (nonUnitDim stays -1) the multiplication is
// skipped and `prod` is just the index-cast product of the comparison bits.
619+ if (nonUnitDim != - 1 ) {
620+ Value v = adaptor. getOperands ()[nonUnitDim];
621+ prod = rewriter. create <arith::MulIOp>(loc, prod, v );
576622 }
577- Type flatMaskType = getTypeConverter ()->convertType (maskType);
578- auto newMask = rewriter.create <mlir::vector::CreateMaskOp>(
579- createMaskOp.getLoc (), flatMaskType, product);
580- rewriter.replaceOp (createMaskOp, newMask);
623+
// Build the replacement create_mask on the converted type with `prod` as its
// single operand, then substitute it for the original op.
624+ Type flatType = getTypeConverter ()->convertType (type);
625+ auto newMask =
626+ rewriter.create <mlir::vector::CreateMaskOp>(loc, flatType, prod);
627+ rewriter.replaceOp (maskOp, newMask);
581628 return success ();
582629 }
583630};
0 commit comments