Skip to content

Commit 9bc615f

Browse files
committed
linearize create_mask
1 parent 8b9ae65 commit 9bc615f

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -20,6 +21,7 @@
2021
#include "mlir/IR/TypeUtilities.h"
2122
#include "mlir/Transforms/DialectConversion.h"
2223
#include "llvm/ADT/ArrayRef.h"
24+
#include <algorithm>
2325
#include <cstdint>
2426
#include <numeric>
2527
#include <optional>
@@ -469,6 +471,11 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
469471
return containsScalableResult;
470472
}
471473

474+
static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
475+
auto shape = createMaskOp.getType().getShape();
476+
return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
477+
}
478+
472479
static bool isNotLinearizable(Operation *op) {
473480

474481
// Only ops that are in the vector dialect, are ConstantLike, or
@@ -485,6 +492,12 @@ static bool isNotLinearizable(Operation *op) {
485492
if (isNotLinearizableBecauseScalable(op))
486493
return true;
487494

495+
if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
496+
if (!isLinearizableCreateMaskOp(createMaskOp)) {
497+
return true;
498+
}
499+
}
500+
488501
return false;
489502
}
490503

@@ -527,12 +540,54 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
527540
});
528541
}
529542

543+
/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
544+
///
545+
/// ```
546+
/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
547+
/// ```
548+
///
549+
/// becomes
550+
///
551+
/// ```
552+
/// %0 = arith.muli %arg0, %arg1 : index
553+
/// %1 = arith.muli %0, %arg2 : index
554+
/// %2 = vector.create_mask %1: vector<16xi1>
555+
/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
556+
/// ```
557+
struct LinearizeVectorCreateMask final
558+
: OpConversionPattern<vector::CreateMaskOp> {
559+
using OpConversionPattern::OpConversionPattern;
560+
561+
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
562+
MLIRContext *context, PatternBenefit benefit = 1)
563+
: OpConversionPattern(typeConverter, context, benefit) {}
564+
565+
LogicalResult
566+
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
567+
ConversionPatternRewriter &rewriter) const override {
568+
569+
VectorType maskType = createMaskOp.getType();
570+
assert(isLinearizableCreateMaskOp(createMaskOp));
571+
572+
Value product = adaptor.getOperands().front();
573+
for (unsigned i = 1; i < maskType.getRank(); ++i) {
574+
product = rewriter.create<mlir::arith::MulIOp>(
575+
createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
576+
}
577+
Type flatMaskType = getTypeConverter()->convertType(maskType);
578+
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
579+
createMaskOp.getLoc(), flatMaskType, product);
580+
rewriter.replaceOp(createMaskOp, newMask);
581+
return success();
582+
}
583+
};
584+
530585
void mlir::vector::populateVectorLinearizeBasePatterns(
531586
const TypeConverter &typeConverter, const ConversionTarget &target,
532587
RewritePatternSet &patterns) {
533588
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
534-
LinearizeVectorBitCast, LinearizeVectorSplat>(
535-
typeConverter, patterns.getContext());
589+
LinearizeVectorCreateMask, LinearizeVectorBitCast,
590+
LinearizeVectorSplat>(typeConverter, patterns.getContext());
536591
}
537592

538593
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
345345
%0 = vector.splat %arg0 : vector<4x[2]xi32>
346346
return %0 : vector<4x[2]xi32>
347347
}
348+
349+
// -----
350+
351+
// CHECK-LABEL: linearize_create_mask
352+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
353+
// CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
354+
// CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
355+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[MULI2]] : vector<16xi1>
356+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
357+
// CHECK: return %[[CAST]] : vector<1x16x1xi1>
358+
func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
359+
%0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
360+
return %0 : vector<1x16x1xi1>
361+
}

0 commit comments

Comments
 (0)