Skip to content

Commit e040553

Browse files
committed
[CIR][Lowering] Fix Vector Comparison Lowering with -fno-signed-char/unsigned operand
1 parent aed448e commit e040553

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,9 @@ mlir::Value CirAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) {
731731
}
732732
auto resTy = addrOp.getType();
733733
auto eltTy = converter->convertType(sourceType);
734-
addrOp = rewriter.create<mlir::LLVM::GEPOp>(loc, resTy, eltTy, addrOp,
735-
indices, mlir::LLVM::GEPNoWrapFlags::inbounds);
734+
addrOp = rewriter.create<mlir::LLVM::GEPOp>(
735+
loc, resTy, eltTy, addrOp, indices,
736+
mlir::LLVM::GEPNoWrapFlags::inbounds);
736737
}
737738

738739
if (auto intTy = mlir::dyn_cast<cir::IntType>(globalAttr.getType())) {
@@ -1205,8 +1206,9 @@ mlir::LogicalResult CIRToLLVMVTTAddrPointOpLowering::matchAndRewrite(
12051206
offsets.push_back(0);
12061207
offsets.push_back(adaptor.getOffset());
12071208
}
1208-
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, resultType, eltType,
1209-
llvmAddr, offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
1209+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1210+
op, resultType, eltType, llvmAddr, offsets,
1211+
mlir::LLVM::GEPNoWrapFlags::inbounds);
12101212
return mlir::success();
12111213
}
12121214

@@ -2052,9 +2054,24 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
20522054
auto elementType = elementTypeIfVector(op.getLhs().getType());
20532055
mlir::Value bitResult;
20542056
if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
2057+
2058+
auto isCIRZeroVector = [](mlir::Value value) {
2059+
if (auto constantOp = value.getDefiningOp<cir::ConstantOp>())
2060+
if (auto zeroAttr =
2061+
mlir::dyn_cast<cir::ZeroAttr>(constantOp.getValue()))
2062+
return true;
2063+
return false;
2064+
};
2065+
2066+
bool shouldUseSigned = intType.isSigned();
2067+
// Special treatment for sign-bit extraction patterns (lt comparison with
2068+
// zero), always use signed comparison to preserve the semantic intent
2069+
if (op.getKind() == cir::CmpOpKind::lt && isCIRZeroVector(op.getRhs()))
2070+
shouldUseSigned = true;
2071+
20552072
bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
20562073
op.getLoc(),
2057-
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
2074+
convertCmpKindToICmpPredicate(op.getKind(), shouldUseSigned),
20582075
adaptor.getLhs(), adaptor.getRhs());
20592076
} else if (mlir::isa<cir::FPTypeInterface>(elementType)) {
20602077
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
@@ -3881,8 +3898,9 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
38813898
op.getAddressPointAttr().getOffset()};
38823899

38833900
assert(eltType && "Shouldn't ever be missing an eltType here");
3884-
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, targetType, eltType,
3885-
symAddr, offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
3901+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
3902+
op, targetType, eltType, symAddr, offsets,
3903+
mlir::LLVM::GEPNoWrapFlags::inbounds);
38863904

38873905
return mlir::success();
38883906
}
@@ -3908,7 +3926,8 @@ mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite(
39083926
llvm::SmallVector<mlir::LLVM::GEPArg> offsets =
39093927
llvm::SmallVector<mlir::LLVM::GEPArg>{op.getIndex()};
39103928
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
3911-
op, targetType, eltType, adaptor.getVptr(), offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
3929+
op, targetType, eltType, adaptor.getVptr(), offsets,
3930+
mlir::LLVM::GEPNoWrapFlags::inbounds);
39123931
return mlir::success();
39133932
}
39143933

@@ -4000,7 +4019,9 @@ mlir::LogicalResult CIRToLLVMInlineAsmOpLowering::matchAndRewrite(
40004019
op, llResTy, llvmOperands, op.getAsmStringAttr(), op.getConstraintsAttr(),
40014020
op.getSideEffectsAttr(),
40024021
/*is_align_stack*/ mlir::UnitAttr(),
4003-
/*tail_call_kind*/ mlir::LLVM::TailCallKindAttr::get(getContext(), mlir::LLVM::tailcallkind::TailCallKind::None),
4022+
/*tail_call_kind*/
4023+
mlir::LLVM::TailCallKindAttr::get(
4024+
getContext(), mlir::LLVM::tailcallkind::TailCallKind::None),
40044025
mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect),
40054026
rewriter.getArrayAttr(opAttrs));
40064027

clang/test/CIR/Lowering/vec-cmp.cir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,15 @@ cir.func @vec_cmp(%0: !cir.vector<!s16i x 16>, %1: !cir.vector<!s16i x 16>) -> (
1414
// MLIR-NEXT: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %arg1 : vector<16xi16>
1515
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16
1616
// MLIR-NEXT: llvm.return
17+
18+
cir.func @vec_cmp_zero(%0: !cir.vector<!u8i x 16>) -> () {
19+
%1 = cir.const #cir.zero : !cir.vector<!u8i x 16>
20+
%2 = cir.vec.cmp(lt, %0, %1) : !cir.vector<!u8i x 16>, !cir.vector<!cir.int<u, 1> x 16>
21+
%3 = cir.cast(bitcast, %2 : !cir.vector<!cir.int<u, 1> x 16>), !cir.int<u, 16>
22+
23+
cir.return
24+
}
25+
26+
// MLIR: llvm.func @vec_cmp_zero
27+
// MLIR: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %{{[0-9]+}} : vector<16xi8>
28+
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16

0 commit comments

Comments
 (0)