Skip to content

Commit 994ceb5

Browse files
authored
[CIR][Dialect] Add BinOpKind_Max (#1201)
This facilitates the implementation of the NEON intrinsic `neon_vmax_v` and of `__builtin_elementwise_max`, and potentially enables future optimizations. CIR BinOp already supports vector types, so the new `max` kind applies to both scalar and vector integer operands. Floating-point max is already supported by FMaxOp.
1 parent 0a1b06c commit 994ceb5

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,14 +1157,15 @@ def BinOpKind_Sub : I32EnumAttrCase<"Sub", 5, "sub">;
11571157
def BinOpKind_And : I32EnumAttrCase<"And", 8, "and">;
11581158
def BinOpKind_Xor : I32EnumAttrCase<"Xor", 9, "xor">;
11591159
def BinOpKind_Or : I32EnumAttrCase<"Or", 10, "or">;
1160+
def BinOpKind_Max : I32EnumAttrCase<"Max", 11, "max">;
11601161

11611162
def BinOpKind : I32EnumAttr<
11621163
"BinOpKind",
11631164
"binary operation (arith and logic) kind",
11641165
[BinOpKind_Mul, BinOpKind_Div, BinOpKind_Rem,
11651166
BinOpKind_Add, BinOpKind_Sub,
11661167
BinOpKind_And, BinOpKind_Xor,
1167-
BinOpKind_Or]> {
1168+
BinOpKind_Or, BinOpKind_Max]> {
11681169
let cppNamespace = "::cir";
11691170
}
11701171

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,6 +2574,15 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
25742574
case cir::BinOpKind::Xor:
25752575
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
25762576
break;
2577+
case cir::BinOpKind::Max:
2578+
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
2579+
auto isUnsigned = isIntTypeUnsigned(type);
2580+
if (isUnsigned)
2581+
rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs);
2582+
else
2583+
rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs);
2584+
}
2585+
break;
25772586
}
25782587

25792588
return mlir::LogicalResult::success();

clang/test/CIR/Lowering/binop-signed-int.cir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module {
77
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
88
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
99
%2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
10+
%100 = cir.alloca !cir.vector<!s32i x 2>, !cir.ptr<!cir.vector<!s32i x 2>>, ["vec1", init] {alignment = 8 : i64}
11+
%101 = cir.alloca !cir.vector<!s32i x 2>, !cir.ptr<!cir.vector<!s32i x 2>>, ["vec2", init] {alignment = 8 : i64}
1012
%3 = cir.const #cir.int<2> : !s32i cir.store %3, %0 : !s32i, !cir.ptr<!s32i>
1113
%4 = cir.const #cir.int<1> : !s32i cir.store %4, %1 : !s32i, !cir.ptr<!s32i>
1214
%5 = cir.load %0 : !cir.ptr<!s32i>, !s32i
@@ -63,6 +65,12 @@ module {
6365
%36 = cir.binop(sub, %32, %33) sat: !s32i
6466
// CHECK: = llvm.intr.ssub.sat{{.*}}(i32, i32) -> i32
6567
cir.store %34, %2 : !s32i, !cir.ptr<!s32i>
68+
%37 = cir.binop(max, %32, %33) : !s32i
69+
// CHECK: = llvm.intr.smax
70+
%38 = cir.load %100 : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
71+
%39 = cir.load %101 : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
72+
%40 = cir.binop(max, %38, %39) : !cir.vector<!s32i x 2>
73+
// CHECK: = llvm.intr.smax({{%.*}}, {{%.*}}) : (vector<2xi32>, vector<2xi32>) -> vector<2xi32>
6674
cir.return
6775
}
6876
}

clang/test/CIR/Lowering/binop-unsigned-int.cir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module {
77
%0 = cir.alloca !u32i, !cir.ptr<!u32i>, ["a", init] {alignment = 4 : i64}
88
%1 = cir.alloca !u32i, !cir.ptr<!u32i>, ["b", init] {alignment = 4 : i64}
99
%2 = cir.alloca !u32i, !cir.ptr<!u32i>, ["x", init] {alignment = 4 : i64}
10+
%100 = cir.alloca !cir.vector<!u32i x 2>, !cir.ptr<!cir.vector<!u32i x 2>>, ["vec1", init] {alignment = 8 : i64}
11+
%101 = cir.alloca !cir.vector<!u32i x 2>, !cir.ptr<!cir.vector<!u32i x 2>>, ["vec2", init] {alignment = 8 : i64}
1012
%3 = cir.const #cir.int<2> : !u32i cir.store %3, %0 : !u32i, !cir.ptr<!u32i>
1113
%4 = cir.const #cir.int<1> : !u32i cir.store %4, %1 : !u32i, !cir.ptr<!u32i>
1214
%5 = cir.load %0 : !cir.ptr<!u32i>, !u32i
@@ -51,6 +53,10 @@ module {
5153
cir.store %34, %2 : !u32i, !cir.ptr<!u32i>
5254
%35 = cir.binop(add, %32, %33) sat: !u32i
5355
%36 = cir.binop(sub, %32, %33) sat: !u32i
56+
%37 = cir.binop(max, %32, %33) : !u32i
57+
%38 = cir.load %100 : !cir.ptr<!cir.vector<!u32i x 2>>, !cir.vector<!u32i x 2>
58+
%39 = cir.load %101 : !cir.ptr<!cir.vector<!u32i x 2>>, !cir.vector<!u32i x 2>
59+
%40 = cir.binop(max, %38, %39) : !cir.vector<!u32i x 2>
5460
cir.return
5561
}
5662
}
@@ -64,8 +70,11 @@ module {
6470
// MLIR: = llvm.shl
6571
// MLIR: = llvm.and
6672
// MLIR: = llvm.xor
73+
// MLIR: = llvm.or
6774
// MLIR: = llvm.intr.uadd.sat{{.*}}(i32, i32) -> i32
6875
// MLIR: = llvm.intr.usub.sat{{.*}}(i32, i32) -> i32
76+
// MLIR: = llvm.intr.umax
77+
// MLIR: = llvm.intr.umax
6978

7079
// LLVM: = mul i32
7180
// LLVM: = udiv i32
@@ -79,3 +88,5 @@ module {
7988
// LLVM: = or i32
8089
// LLVM: = call i32 @llvm.uadd.sat.i32
8190
// LLVM: = call i32 @llvm.usub.sat.i32
91+
// LLVM: = call i32 @llvm.umax.i32
92+
// LLVM: = call <2 x i32> @llvm.umax.v2i32

0 commit comments

Comments
 (0)