Skip to content

Commit 21e8647

Browse files
authored
[CIR][CIRGen][Builtin][Neon] Lower neon_vqadds_s32 (#1200)
This can't be simply implemented by our CIR Add via LLVM::AddOp, as i[t's saturated add.](https://godbolt.org/z/MxqGrj6fP)
1 parent cfe7c63 commit 21e8647

File tree

9 files changed

+126
-40
lines changed

9 files changed

+126
-40
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
419419
}
420420

421421
mlir::Value createSub(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
422-
bool hasNSW = false) {
422+
bool hasNSW = false, bool saturated = false) {
423423
auto op = create<cir::BinOp>(lhs.getLoc(), lhs.getType(),
424424
cir::BinOpKind::Sub, lhs, rhs);
425425
if (hasNUW)
426426
op.setNoUnsignedWrap(true);
427427
if (hasNSW)
428428
op.setNoSignedWrap(true);
429+
if (saturated)
430+
op.setSaturated(true);
429431
return op;
430432
}
431433

@@ -438,13 +440,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
438440
}
439441

440442
mlir::Value createAdd(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
441-
bool hasNSW = false) {
443+
bool hasNSW = false, bool saturated = false) {
442444
auto op = create<cir::BinOp>(lhs.getLoc(), lhs.getType(),
443445
cir::BinOpKind::Add, lhs, rhs);
444446
if (hasNUW)
445447
op.setNoUnsignedWrap(true);
446448
if (hasNSW)
447449
op.setNoSignedWrap(true);
450+
if (saturated)
451+
op.setSaturated(true);
448452
return op;
449453
}
450454

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,12 +1192,14 @@ def BinOp : CIR_Op<"binop", [Pure,
11921192
let arguments = (ins Arg<BinOpKind, "binop kind">:$kind,
11931193
CIR_AnyType:$lhs, CIR_AnyType:$rhs,
11941194
UnitAttr:$no_unsigned_wrap,
1195-
UnitAttr:$no_signed_wrap);
1195+
UnitAttr:$no_signed_wrap,
1196+
UnitAttr:$saturated);
11961197

11971198
let assemblyFormat = [{
11981199
`(` $kind `,` $lhs `,` $rhs `)`
11991200
(`nsw` $no_signed_wrap^)?
12001201
(`nuw` $no_unsigned_wrap^)?
1202+
(`sat` $saturated^)?
12011203
`:` type($lhs) attr-dict
12021204
}];
12031205

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,7 +2834,7 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
28342834
case NEON::BI__builtin_neon_vqaddh_u16:
28352835
llvm_unreachable(" neon_vqaddh_u16 NYI ");
28362836
case NEON::BI__builtin_neon_vqadds_s32:
2837-
llvm_unreachable(" neon_vqadds_s32 NYI ");
2837+
return builder.createAdd(ops[0], ops[1], false, false, true);
28382838
case NEON::BI__builtin_neon_vqadds_u32:
28392839
llvm_unreachable(" neon_vqadds_u32 NYI ");
28402840
case NEON::BI__builtin_neon_vqdmulhh_s16:
@@ -2983,7 +2983,7 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
29832983
case NEON::BI__builtin_neon_vqsubh_u16:
29842984
llvm_unreachable(" neon_vqsubh_u16 NYI ");
29852985
case NEON::BI__builtin_neon_vqsubs_s32:
2986-
llvm_unreachable(" neon_vqsubs_s32 NYI ");
2986+
return builder.createSub(ops[0], ops[1], false, false, true);
29872987
case NEON::BI__builtin_neon_vqsubs_u32:
29882988
llvm_unreachable(" neon_vqsubs_u32 NYI ");
29892989
case NEON::BI__builtin_neon_vrecped_f64:

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3785,6 +3785,7 @@ LogicalResult cir::AtomicFetch::verify() {
37853785

37863786
LogicalResult cir::BinOp::verify() {
37873787
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
3788+
bool saturated = getSaturated();
37883789

37893790
if (!isa<cir::IntType>(getType()) && noWrap)
37903791
return emitError()
@@ -3794,9 +3795,18 @@ LogicalResult cir::BinOp::verify() {
37943795
getKind() == cir::BinOpKind::Sub ||
37953796
getKind() == cir::BinOpKind::Mul;
37963797

3798+
bool saturatedOps =
3799+
getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
3800+
37973801
if (noWrap && !noWrapOps)
37983802
return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
37993803
"'sub' and 'mul'";
3804+
if (saturated && !saturatedOps)
3805+
return emitError() << "The saturated flag is applicable to opcodes: 'add' "
3806+
"and 'sub'";
3807+
if (noWrap && saturated)
3808+
return emitError() << "The nsw/nuw flags and the saturated flag are "
3809+
"mutually exclusive";
38003810

38013811
bool complexOps =
38023812
getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,6 +2452,13 @@ CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
24522452
return mlir::LLVM::IntegerOverflowFlags::none;
24532453
}
24542454

2455+
static bool isIntTypeUnsigned(mlir::Type type) {
2456+
// TODO: Ideally, we should only need to check cir::IntType here.
2457+
return mlir::isa<cir::IntType>(type)
2458+
? mlir::cast<cir::IntType>(type).isUnsigned()
2459+
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
2460+
}
2461+
24552462
mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
24562463
cir::BinOp op, OpAdaptor adaptor,
24572464
mlir::ConversionPatternRewriter &rewriter) const {
@@ -2464,65 +2471,81 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
24642471
"operand type not supported yet");
24652472

24662473
auto llvmTy = getTypeConverter()->convertType(op.getType());
2474+
mlir::Type llvmEltTy =
2475+
mlir::isa<mlir::VectorType>(llvmTy)
2476+
? mlir::cast<mlir::VectorType>(llvmTy).getElementType()
2477+
: llvmTy;
24672478
auto rhs = adaptor.getRhs();
24682479
auto lhs = adaptor.getLhs();
24692480

24702481
type = elementTypeIfVector(type);
24712482

24722483
switch (op.getKind()) {
24732484
case cir::BinOpKind::Add:
2474-
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
2485+
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
2486+
if (op.getSaturated()) {
2487+
if (isIntTypeUnsigned(type)) {
2488+
rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs);
2489+
break;
2490+
}
2491+
rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs);
2492+
break;
2493+
}
24752494
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
24762495
getIntOverflowFlag(op));
2477-
else
2478-
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs);
2496+
} else
2497+
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs);
24792498
break;
24802499
case cir::BinOpKind::Sub:
2481-
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
2500+
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
2501+
if (op.getSaturated()) {
2502+
if (isIntTypeUnsigned(type)) {
2503+
rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs);
2504+
break;
2505+
}
2506+
rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs);
2507+
break;
2508+
}
24822509
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
24832510
getIntOverflowFlag(op));
2484-
else
2485-
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs);
2511+
} else
2512+
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs);
24862513
break;
24872514
case cir::BinOpKind::Mul:
2488-
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
2515+
if (mlir::isa<mlir::IntegerType>(llvmEltTy))
24892516
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
24902517
getIntOverflowFlag(op));
24912518
else
2492-
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs);
2519+
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs);
24932520
break;
24942521
case cir::BinOpKind::Div:
2495-
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
2496-
auto isUnsigned = mlir::isa<cir::IntType>(type)
2497-
? mlir::cast<cir::IntType>(type).isUnsigned()
2498-
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
2522+
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
2523+
auto isUnsigned = isIntTypeUnsigned(type);
24992524
if (isUnsigned)
2500-
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs);
2525+
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs);
25012526
else
2502-
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs);
2527+
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs);
25032528
} else
2504-
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs);
2529+
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs);
25052530
break;
25062531
case cir::BinOpKind::Rem:
2507-
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
2508-
auto isUnsigned = mlir::isa<cir::IntType>(type)
2509-
? mlir::cast<cir::IntType>(type).isUnsigned()
2510-
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
2532+
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
2533+
auto isUnsigned = isIntTypeUnsigned(type);
25112534
if (isUnsigned)
2512-
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs);
2535+
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs);
25132536
else
2514-
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs);
2537+
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs);
25152538
} else
2516-
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, llvmTy, lhs, rhs);
2539+
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs);
25172540
break;
25182541
case cir::BinOpKind::And:
2519-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, llvmTy, lhs, rhs);
2542+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs);
25202543
break;
25212544
case cir::BinOpKind::Or:
2522-
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, llvmTy, lhs, rhs);
2545+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs);
25232546
break;
25242547
case cir::BinOpKind::Xor:
2525-
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmTy, lhs, rhs);
2548+
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
25262549
break;
25272550
}
25282551

clang/test/CIR/CodeGen/AArch64/neon.c

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9750,12 +9750,16 @@ poly16x8_t test_vmull_p8(poly8x8_t a, poly8x8_t b) {
97509750
// return vqaddh_s16(a, b);
97519751
// }
97529752

9753-
// NYI-LABEL: @test_vqadds_s32(
9754-
// NYI: [[VQADDS_S32_I:%.*]] = call i32 @llvm.aarch64.neon.sqadd.i32(i32 %a, i32 %b)
9755-
// NYI: ret i32 [[VQADDS_S32_I]]
9756-
// int32_t test_vqadds_s32(int32_t a, int32_t b) {
9757-
// return vqadds_s32(a, b);
9758-
// }
9753+
int32_t test_vqadds_s32(int32_t a, int32_t b) {
9754+
return vqadds_s32(a, b);
9755+
9756+
// CIR: vqadds_s32
9757+
// CIR: cir.binop(add, {{%.*}}, {{%.*}}) sat : !s32i
9758+
9759+
// LLVM:{{.*}}test_vqadds_s32(i32{{.*}}[[a:%.*]], i32{{.*}}[[b:%.*]])
9760+
// LLVM: [[VQADDS_S32_I:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[a]], i32 [[b]])
9761+
// LLVM: ret i32 [[VQADDS_S32_I]]
9762+
}
97599763

97609764
// NYI-LABEL: @test_vqaddd_s64(
97619765
// NYI: [[VQADDD_S64_I:%.*]] = call i64 @llvm.aarch64.neon.sqadd.i64(i64 %a, i64 %b)
@@ -9821,9 +9825,16 @@ poly16x8_t test_vmull_p8(poly8x8_t a, poly8x8_t b) {
98219825
// NYI-LABEL: @test_vqsubs_s32(
98229826
// NYI: [[VQSUBS_S32_I:%.*]] = call i32 @llvm.aarch64.neon.sqsub.i32(i32 %a, i32 %b)
98239827
// NYI: ret i32 [[VQSUBS_S32_I]]
9824-
// int32_t test_vqsubs_s32(int32_t a, int32_t b) {
9825-
// return vqsubs_s32(a, b);
9826-
// }
9828+
int32_t test_vqsubs_s32(int32_t a, int32_t b) {
9829+
return vqsubs_s32(a, b);
9830+
9831+
// CIR: vqsubs_s32
9832+
// CIR: cir.binop(sub, {{%.*}}, {{%.*}}) sat : !s32i
9833+
9834+
// LLVM:{{.*}}test_vqsubs_s32(i32{{.*}}[[a:%.*]], i32{{.*}}[[b:%.*]])
9835+
// LLVM: [[VQSUBS_S32_I:%.*]] = call i32 @llvm.ssub.sat.i32(i32 [[a]], i32 [[b]])
9836+
// LLVM: ret i32 [[VQSUBS_S32_I]]
9837+
}
98279838

98289839
// NYI-LABEL: @test_vqsubd_s64(
98299840
// NYI: [[VQSUBD_S64_I:%.*]] = call i64 @llvm.aarch64.neon.sqsub.i64(i64 %a, i64 %b)

clang/test/CIR/IR/invalid.cir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,33 @@ cir.func @bad_binop_for_nowrap(%x: !u32i, %y: !u32i) {
10911091

10921092
// -----
10931093

1094+
!u32i = !cir.int<u, 32>
1095+
1096+
cir.func @bad_binop_for_saturated(%x: !u32i, %y: !u32i) {
1097+
// expected-error@+1 {{The saturated flag is applicable to opcodes: 'add' and 'sub'}}
1098+
%0 = cir.binop(div, %x, %y) sat : !u32i
1099+
}
1100+
1101+
// -----
1102+
1103+
!s32i = !cir.int<s, 32>
1104+
1105+
cir.func @no_nsw_for_saturated(%x: !s32i, %y: !s32i) {
1106+
// expected-error@+1 {{The nsw/nuw flags and the saturated flag are mutually exclusive}}
1107+
%0 = cir.binop(add, %x, %y) nsw sat : !s32i
1108+
}
1109+
1110+
// -----
1111+
1112+
!s32i = !cir.int<s, 32>
1113+
1114+
cir.func @no_nuw_for_saturated(%x: !s32i, %y: !s32i) {
1115+
// expected-error@+1 {{The nsw/nuw flags and the saturated flag are mutually exclusive}}
1116+
%0 = cir.binop(add, %x, %y) nuw sat : !s32i
1117+
}
1118+
1119+
// -----
1120+
10941121
!s32i = !cir.int<s, 32>
10951122

10961123
module {

clang/test/CIR/Lowering/binop-signed-int.cir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ module {
5858
%33 = cir.load %1 : !cir.ptr<!s32i>, !s32i
5959
%34 = cir.binop(or, %32, %33) : !s32i
6060
// CHECK: = llvm.or
61+
%35 = cir.binop(add, %32, %33) sat: !s32i
62+
// CHECK: = llvm.intr.sadd.sat{{.*}}(i32, i32) -> i32
63+
%36 = cir.binop(sub, %32, %33) sat: !s32i
64+
// CHECK: = llvm.intr.ssub.sat{{.*}}(i32, i32) -> i32
6165
cir.store %34, %2 : !s32i, !cir.ptr<!s32i>
6266
cir.return
6367
}

clang/test/CIR/Lowering/binop-unsigned-int.cir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ module {
4949
%33 = cir.load %1 : !cir.ptr<!u32i>, !u32i
5050
%34 = cir.binop(or, %32, %33) : !u32i
5151
cir.store %34, %2 : !u32i, !cir.ptr<!u32i>
52+
%35 = cir.binop(add, %32, %33) sat: !u32i
53+
%36 = cir.binop(sub, %32, %33) sat: !u32i
5254
cir.return
5355
}
5456
}
@@ -62,7 +64,8 @@ module {
6264
// MLIR: = llvm.shl
6365
// MLIR: = llvm.and
6466
// MLIR: = llvm.xor
65-
// MLIR: = llvm.or
67+
// MLIR: = llvm.intr.uadd.sat{{.*}}(i32, i32) -> i32
68+
// MLIR: = llvm.intr.usub.sat{{.*}}(i32, i32) -> i32
6669

6770
// LLVM: = mul i32
6871
// LLVM: = udiv i32
@@ -74,3 +77,5 @@ module {
7477
// LLVM: = and i32
7578
// LLVM: = xor i32
7679
// LLVM: = or i32
80+
// LLVM: = call i32 @llvm.uadd.sat.i32
81+
// LLVM: = call i32 @llvm.usub.sat.i32

0 commit comments

Comments
 (0)