Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
let summary = "signed integer maximum operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
let summary = "unsigned integer maximum operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
let summary = "signed integer minimum operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
let summary = "unsigned integer minimum operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}


Expand Down
68 changes: 68 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;

// Select signed min value of two integer attributes and store to the result
def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">;

// Select unsigned min value of two integer attributes and store to the result
def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">;

// Select signed max value of two integer attributes and store to the result
def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">;

// Select unsigned max value of two integer attributes and store to the result
def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">;

// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;

Expand Down Expand Up @@ -202,6 +214,62 @@ def MulUIExtendedToMulI :
[(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;

//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//

// maxsi is commutative and will be canonicalized to have its constants appear
// as the second operand.

// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1))
def MaxSIMaxSIConstant :
Pat<(Arith_MaxSIOp:$res
(Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>;

//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//

// maxui is commutative and will be canonicalized to have its constants appear
// as the second operand.

// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1))
def MaxUIMaxUIConstant :
Pat<(Arith_MaxUIOp:$res
(Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>;

//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//

// minsi is commutative and will be canonicalized to have its constants appear
// as the second operand.

// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1))
def MinSIMinSIConstant :
Pat<(Arith_MinSIOp:$res
(Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>;

//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//

// minui is commutative and will be canonicalized to have its constants appear
// as the second operand.

// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1))
def MinUIMinUIConstant :
Pat<(Arith_MinUIOp:$res
(Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>;

//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 40 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}

static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin);
}

static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin);
}

static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax);
}

static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax);
}

// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
Expand Down Expand Up @@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
});
}

void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<MaxSIMaxSIConstant>(context);
}

//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
Expand All @@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
});
}

void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<MaxUIMaxUIConstant>(context);
}

//===----------------------------------------------------------------------===//
// MinimumFOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
});
}

void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<MinSIMinSIConstant>(context);
}

//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
Expand All @@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
});
}

void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<MinUIMinUIConstant>(context);
}

//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
Expand Down
97 changes: 96 additions & 1 deletion mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 {

// -----

// CHECK-LABEL: @maxsiMaxsiConst1
// CHECK: %[[C42:.+]] = arith.constant 42 : i32
// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32
// CHECK: return %[[RES]]
func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 {
%c17 = arith.constant 17 : i32
%c42 = arith.constant 42 : i32
%max1 = arith.maxsi %arg0, %c17 : i32
%max2 = arith.maxsi %max1, %c42 : i32
return %max2 : i32
}

// CHECK-LABEL: @maxsiMaxsiConst2
// CHECK: %[[C21:.+]] = arith.constant 21 : i32
// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32
// CHECK: return %[[RES]]
func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 {
%c7 = arith.constant 7 : i32
%c21 = arith.constant 21 : i32
%max1 = arith.maxsi %arg0, %c7 : i32
%max2 = arith.maxsi %c21, %max1 : i32
return %max2 : i32
}

// CHECK-LABEL: test_maxsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
Expand Down Expand Up @@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {

// -----

// CHECK-LABEL: @maxuiMaxuiConst1
// CHECK: %[[C42:.+]] = arith.constant 42 : index
// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index
// CHECK: return %[[RES]]
func.func @maxuiMaxuiConst1(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%max1 = arith.maxui %arg0, %c17 : index
%max2 = arith.maxui %max1, %c42 : index
return %max2 : index
}

// CHECK-LABEL: @maxuiMaxuiConst2
// CHECK: %[[C21:.+]] = arith.constant 21 : index
// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C21]] : index
// CHECK: return %[[RES]]
func.func @maxuiMaxuiConst2(%arg0: index) -> index {
%c7 = arith.constant 7 : index
%c21 = arith.constant 21 : index
%max1 = arith.maxui %arg0, %c7 : index
%max2 = arith.maxui %c21, %max1 : index
return %max2 : index
}

// CHECK-LABEL: test_maxui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
Expand Down Expand Up @@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {

// -----

// CHECK-LABEL: @minsiMinsiConst1
// CHECK: %[[C17:.+]] = arith.constant 17 : i32
// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32
// CHECK: return %[[RES]]
func.func @minsiMinsiConst1(%arg0: i32) -> i32 {
%c17 = arith.constant 17 : i32
%c42 = arith.constant 42 : i32
%min1 = arith.minsi %arg0, %c17 : i32
%min2 = arith.minsi %min1, %c42 : i32
return %min2 : i32
}

// CHECK-LABEL: @minsiMinsiConst2
// CHECK: %[[C7:.+]] = arith.constant 7 : i32
// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C7]] : i32
// CHECK: return %[[RES]]
func.func @minsiMinsiConst2(%arg0: i32) -> i32 {
%c7 = arith.constant 7 : i32
%c21 = arith.constant 21 : i32
%min1 = arith.minsi %arg0, %c7 : i32
%min2 = arith.minsi %c21, %min1 : i32
return %min2 : i32
}

// CHECK-LABEL: test_minsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
Expand Down Expand Up @@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {

// -----

// CHECK-LABEL: @minuiMinuiConst1
// CHECK: %[[C17:.+]] = arith.constant 17 : index
// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index
// CHECK: return %[[RES]]
func.func @minuiMinuiConst1(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
%min1 = arith.minui %arg0, %c17 : index
%min2 = arith.minui %min1, %c42 : index
return %min2 : index
}

// CHECK-LABEL: @minuiMinuiConst2
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C7]] : index
// CHECK: return %[[RES]]
func.func @minuiMinuiConst2(%arg0: index) -> index {
%c7 = arith.constant 7 : index
%c21 = arith.constant 21 : index
%min1 = arith.minui %arg0, %c7 : index
%min2 = arith.minui %c21, %min1 : index
return %min2 : index
}

// CHECK-LABEL: test_minui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
Expand Down Expand Up @@ -3377,4 +3473,3 @@ func.func @unreachable() {
%add = arith.addi %add, %c1_i64 : i64
cf.br ^unreachable
}