Skip to content

Commit 06b0af0

Browse files
committed
[CIR][Lowering] Fix Vector Comparison Lowering with -fno-signed-char/unsigned operand
1 parent aed448e commit 06b0af0

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2052,9 +2052,24 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
20522052
auto elementType = elementTypeIfVector(op.getLhs().getType());
20532053
mlir::Value bitResult;
20542054
if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
2055+
2056+
auto isCIRZeroVector = [](mlir::Value value) {
2057+
if (auto constantOp = value.getDefiningOp<cir::ConstantOp>())
2058+
if (auto zeroAttr =
2059+
mlir::dyn_cast<cir::ZeroAttr>(constantOp.getValue()))
2060+
return true;
2061+
return false;
2062+
};
2063+
2064+
bool shouldUseSigned = intType.isSigned();
2065+
// Special treatment for sign-bit extraction patterns (lt comparison with
2066+
// zero), always use signed comparison to preserve the semantic intent
2067+
if (op.getKind() == cir::CmpOpKind::lt && isCIRZeroVector(op.getRhs()))
2068+
shouldUseSigned = true;
2069+
20552070
bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
20562071
op.getLoc(),
2057-
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
2072+
convertCmpKindToICmpPredicate(op.getKind(), shouldUseSigned),
20582073
adaptor.getLhs(), adaptor.getRhs());
20592074
} else if (mlir::isa<cir::FPTypeInterface>(elementType)) {
20602075
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(

clang/test/CIR/Lowering/vec-cmp.cir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
!s16i = !cir.int<s, 16>
55
!u16i = !cir.int<u, 16>
6+
!u8i = !cir.int<u, 8>
67

78
cir.func @vec_cmp(%0: !cir.vector<!s16i x 16>, %1: !cir.vector<!s16i x 16>) -> () {
89
%2 = cir.vec.cmp(lt, %0, %1) : !cir.vector<!s16i x 16>, !cir.vector<!cir.int<u, 1> x 16>
@@ -14,3 +15,15 @@ cir.func @vec_cmp(%0: !cir.vector<!s16i x 16>, %1: !cir.vector<!s16i x 16>) -> (
1415
// MLIR-NEXT: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %arg1 : vector<16xi16>
1516
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16
1617
// MLIR-NEXT: llvm.return
18+
19+
cir.func @vec_cmp_zero(%0: !cir.vector<!u8i x 16>) -> () {
20+
%1 = cir.const #cir.zero : !cir.vector<!u8i x 16>
21+
%2 = cir.vec.cmp(lt, %0, %1) : !cir.vector<!u8i x 16>, !cir.vector<!cir.int<u, 1> x 16>
22+
%3 = cir.cast(bitcast, %2 : !cir.vector<!cir.int<u, 1> x 16>), !cir.int<u, 16>
23+
24+
cir.return
25+
}
26+
27+
// MLIR: llvm.func @vec_cmp_zero
28+
// MLIR: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %{{[0-9]+}} : vector<16xi8>
29+
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16

0 commit comments

Comments
 (0)