Skip to content

Commit b4ae361

Browse files
committed
updates
1 parent 9bc615f commit b4ae361

File tree

2 files changed

+77
-24
lines changed

2 files changed

+77
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/TypeUtilities.h"
2222
#include "mlir/Transforms/DialectConversion.h"
2323
#include "llvm/ADT/ArrayRef.h"
24+
#include "llvm/ADT/STLExtras.h"
2425
#include <algorithm>
2526
#include <cstdint>
2627
#include <numeric>
@@ -471,9 +472,12 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
471472
return containsScalableResult;
472473
}
473474

474-
static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
475-
auto shape = createMaskOp.getType().getShape();
476-
return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
475+
static bool
476+
isCreateMaskWithAtMostOneNonUnit(vector::CreateMaskOp createMaskOp) {
477+
ArrayRef<int64_t> shape = createMaskOp.getType().getShape();
478+
bool multipleNonUnitDim =
479+
llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) > 1;
480+
return !multipleNonUnitDim;
477481
}
478482

479483
static bool isNotLinearizable(Operation *op) {
@@ -493,7 +497,7 @@ static bool isNotLinearizable(Operation *op) {
493497
return true;
494498

495499
if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
496-
if (!isLinearizableCreateMaskOp(createMaskOp)) {
500+
if (!isCreateMaskWithAtMostOneNonUnit(createMaskOp)) {
497501
return true;
498502
}
499503
}
@@ -540,7 +544,8 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
540544
});
541545
}
542546

543-
/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
547+
/// Linearize a vector.create_mask that has at most 1 non-unit dimension.
548+
/// Example:
544549
///
545550
/// ```
546551
/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
@@ -549,11 +554,30 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
549554
/// becomes
550555
///
551556
/// ```
552-
/// %0 = arith.muli %arg0, %arg1 : index
553-
/// %1 = arith.muli %0, %arg2 : index
554-
/// %2 = vector.create_mask %1: vector<16xi1>
557+
/// [...]
558+
/// %2 = vector.create_mask %prod: vector<16xi1>
555559
/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
556560
/// ```
561+
///
562+
/// where %prod above the product of the (clamped) dimension-wise masking ranges
563+
/// %arg0, %arg1, and %arg2.
564+
///
565+
/// This is equivalent to choosing the rank-1 masking range as:
566+
/// 1) %arg1 if %arg0 and %arg2 are stricty positive
567+
/// 2) 0 if either %arg0 or %arg2 are 0 or negative.
568+
///
569+
/// Specifically, %prod is obtained as
570+
///
571+
/// ```
572+
/// %true = arith.constant true
573+
/// %zero = arith.constant 0 : index
574+
/// %0 = arith.cmpi sgt, %arg0, %zero : index
575+
/// %1 = arith.muli %true, %0 : i1
576+
/// %2 = arith.cmpi sgt, %arg2, %zero : index
577+
/// %3 = arith.muli %1, %2 : i1
578+
/// %4 = arith.index_cast %3 : i1 to index
579+
/// %prod = arith.muli %4, %arg1 : index
580+
/// ```
557581
struct LinearizeVectorCreateMask final
558582
: OpConversionPattern<vector::CreateMaskOp> {
559583
using OpConversionPattern::OpConversionPattern;
@@ -563,21 +587,44 @@ struct LinearizeVectorCreateMask final
563587
: OpConversionPattern(typeConverter, context, benefit) {}
564588

565589
LogicalResult
566-
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
590+
matchAndRewrite(vector::CreateMaskOp maskOp, OpAdaptor adaptor,
567591
ConversionPatternRewriter &rewriter) const override {
568592

569-
VectorType maskType = createMaskOp.getType();
570-
assert(isLinearizableCreateMaskOp(createMaskOp));
593+
VectorType type = maskOp.getType();
594+
assert(isCreateMaskWithAtMostOneNonUnit(maskOp) &&
595+
"expected linearizable create_mask");
596+
597+
Location loc = maskOp.getLoc();
598+
599+
// First, get the product of (clamped) mask sizes in the unit-dimensions.
600+
Value prod = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
601+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
602+
int nonUnitDim = -1;
603+
for (unsigned i = 0; i < type.getRank(); ++i) {
604+
auto v = adaptor.getOperands()[i];
605+
auto dimSize = type.getDimSize(i);
606+
if (dimSize <= 1) {
607+
Value nxt = rewriter.create<arith::CmpIOp>(
608+
loc, arith::CmpIPredicate::sgt, v, zero);
609+
prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
610+
} else {
611+
assert(nonUnitDim == -1 && "at most 1 non-unit expected");
612+
nonUnitDim = i;
613+
}
614+
}
615+
prod =
616+
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), prod);
571617

572-
Value product = adaptor.getOperands().front();
573-
for (unsigned i = 1; i < maskType.getRank(); ++i) {
574-
product = rewriter.create<mlir::arith::MulIOp>(
575-
createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
618+
// Finally, multiply by the size in the dimension that is not unit.
619+
if (nonUnitDim != -1) {
620+
Value v = adaptor.getOperands()[nonUnitDim];
621+
prod = rewriter.create<arith::MulIOp>(loc, prod, v);
576622
}
577-
Type flatMaskType = getTypeConverter()->convertType(maskType);
578-
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
579-
createMaskOp.getLoc(), flatMaskType, product);
580-
rewriter.replaceOp(createMaskOp, newMask);
623+
624+
Type flatType = getTypeConverter()->convertType(type);
625+
auto newMask =
626+
rewriter.create<mlir::vector::CreateMaskOp>(loc, flatType, prod);
627+
rewriter.replaceOp(maskOp, newMask);
581628
return success();
582629
}
583630
};

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
350350

351351
// CHECK-LABEL: linearize_create_mask
352352
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
353-
// CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
354-
// CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
355-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[MULI2]] : vector<16xi1>
356-
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
357-
// CHECK: return %[[CAST]] : vector<1x16x1xi1>
353+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
354+
// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
355+
// CHECK: %[[CMP0:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
356+
// CHECK: %[[MUL0:.*]] = arith.muli %[[TRUE]], %[[CMP0]] : i1
357+
// CHECK: %[[CMP1:.*]] = arith.cmpi sgt, %[[ARG2]], %[[C0]] : index
358+
// CHECK: %[[MUL1:.*]] = arith.muli %[[MUL0]], %[[CMP1]] : i1
359+
// CHECK: %[[CAST:.*]] = arith.index_cast %[[MUL1]] : i1 to index
360+
// CHECK: %[[MUL2:.*]] = arith.muli %[[CAST]], %[[ARG1]] : index
361+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[MUL2]] : vector<16xi1>
362+
// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
363+
// CHECK: return %[[CAST2]] : vector<1x16x1xi1>
358364
func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
359365
%0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
360366
return %0 : vector<1x16x1xi1>

0 commit comments

Comments
 (0)