diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b9cef003fa365..86bbbc2196a8b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -20,6 +21,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include #include #include #include @@ -469,6 +472,14 @@ static bool isNotLinearizableBecauseScalable(Operation *op) { return containsScalableResult; } +static bool +isCreateMaskWithAtMostOneNonUnit(vector::CreateMaskOp createMaskOp) { + ArrayRef shape = createMaskOp.getType().getShape(); + bool multipleNonUnitDim = + llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) > 1; + return !multipleNonUnitDim; +} + static bool isNotLinearizable(Operation *op) { // Only ops that are in the vector dialect, are ConstantLike, or @@ -485,6 +496,12 @@ static bool isNotLinearizable(Operation *op) { if (isNotLinearizableBecauseScalable(op)) return true; + if (auto createMaskOp = dyn_cast(op)) { + if (!isCreateMaskWithAtMostOneNonUnit(createMaskOp)) { + return true; + } + } + return false; } @@ -527,12 +544,95 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, }); } +/// Linearize a vector.create_mask that has at most 1 non-unit dimension. +/// For example, +/// ``` +/// %mask3 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1> +/// ``` +/// +/// becomes, +/// ``` +/// [...] 
+/// %mask1 = vector.create_mask %prod: vector<16xi1> +/// %mask3 = vector.shape_cast %mask1: vector<16xi1> to vector<1x16x1xi1> +/// ``` +/// +/// where %prod above is the product of the (clamped) dimension-wise masking ranges +/// %arg0, %arg1, and %arg2. +/// +/// This is equivalent to choosing the rank-1 masking range as: +/// 1) %arg1 if %arg0 and %arg2 are strictly positive +/// 2) 0 if either %arg0 or %arg2 is 0 or negative. +/// +/// Specifically, %prod is obtained as +/// +/// ``` +/// %true = arith.constant true +/// %zero = arith.constant 0 : index +/// %0 = arith.cmpi sgt, %arg0, %zero : index +/// %1 = arith.muli %true, %0 : i1 +/// %2 = arith.cmpi sgt, %arg2, %zero : index +/// %3 = arith.muli %1, %2 : i1 +/// %4 = arith.index_cast %3 : i1 to index +/// %prod = arith.muli %4, %arg1 : index +/// ``` +struct LinearizeVectorCreateMask final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorCreateMask(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::CreateMaskOp maskOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + VectorType type = maskOp.getType(); + assert(isCreateMaskWithAtMostOneNonUnit(maskOp) && + "expected linearizable create_mask"); + + Location loc = maskOp.getLoc(); + + // First, get the product of (clamped) mask sizes in the unit-dimensions. 
+ Value prod = rewriter.create(loc, 1, 1); + Value zero = rewriter.create(loc, 0); + int nonUnitDim = -1; + for (unsigned i = 0; i < type.getRank(); ++i) { + Value dimRange = adaptor.getOperands()[i]; + int64_t dimSize = type.getDimSize(i); + if (dimSize <= 1) { + Value nxt = rewriter.create( + loc, arith::CmpIPredicate::sgt, dimRange, zero); + prod = rewriter.create(loc, prod, nxt); + } else { + assert(nonUnitDim == -1 && "at most 1 non-unit expected"); + nonUnitDim = i; + } + } + prod = + rewriter.create(loc, rewriter.getIndexType(), prod); + + // Finally, multiply by the size in the dimension that is not unit. + if (nonUnitDim != -1) { + Value v = adaptor.getOperands()[nonUnitDim]; + prod = rewriter.create(loc, prod, v); + } + + Type flatType = getTypeConverter()->convertType(type); + auto newMask = + rewriter.create(loc, flatType, prod); + rewriter.replaceOp(maskOp, newMask); + return success(); + } +}; + void mlir::vector::populateVectorLinearizeBasePatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + LinearizeVectorCreateMask, LinearizeVectorBitCast, + LinearizeVectorSplat>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 01ad1ac48b012..6437b5eefa9bb 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -345,3 +345,23 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { %0 = vector.splat %arg0 : vector<4x[2]xi32> return %0 : vector<4x[2]xi32> } + +// ----- + +// CHECK-LABEL: linearize_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK: 
%[[CMP0:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index +// CHECK: %[[MUL0:.*]] = arith.muli %[[TRUE]], %[[CMP0]] : i1 +// CHECK: %[[CMP1:.*]] = arith.cmpi sgt, %[[ARG2]], %[[C0]] : index +// CHECK: %[[MUL1:.*]] = arith.muli %[[MUL0]], %[[CMP1]] : i1 +// CHECK: %[[CAST:.*]] = arith.index_cast %[[MUL1]] : i1 to index +// CHECK: %[[MUL2:.*]] = arith.muli %[[CAST]], %[[ARG1]] : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[MUL2]] : vector<16xi1> +// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1> +// CHECK: return %[[CAST2]] : vector<1x16x1xi1> +func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> { + %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1> + return %0 : vector<1x16x1xi1> +}