diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index 44141844f42f4..35d21aa1d66d9 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI, std::optional> ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2, + unsigned DstScalarSizeInBits, unsigned ExtOp, const MachineRegisterInfo &MRI); /// Test if the given value is known to have exactly one bit set. This differs diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp index bf8e847011d7c..10c72641ce2df 100644 --- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp @@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc, assert(SrcOps.size() == 3 && "Invalid sources"); assert(DstOps.size() == 1 && "Invalid dsts"); LLT SrcTy = SrcOps[1].getLLTTy(*getMRI()); + LLT DstTy = DstOps[0].getLLTTy(*getMRI()); + auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false); - if (std::optional> Cst = - ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(), - SrcOps[2].getReg(), *getMRI())) { + if (std::optional> Cst = ConstantFoldICmp( + SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(), + DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) { if (SrcTy.isVector()) return buildBuildVectorConstant(DstOps[0], *Cst); return buildConstant(DstOps[0], Cst->front()); diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 223d69c362185..04e69f882bb43 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1027,39 +1027,50 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI, std::optional> llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2, + unsigned DstScalarSizeInBits, unsigned ExtOp, const MachineRegisterInfo &MRI) { - LLT Ty = MRI.getType(Op1); - if (Ty != MRI.getType(Op2)) - return std::nullopt; + assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT || + ExtOp == TargetOpcode::G_ANYEXT); - auto TryFoldScalar = [&MRI, Pred](Register LHS, - Register RHS) -> std::optional { - auto LHSCst = getIConstantVRegVal(LHS, MRI); + const LLT Ty = MRI.getType(Op1); + + auto GetICmpResultCst = [&](bool IsTrue) { + if (IsTrue) + return ExtOp == TargetOpcode::G_SEXT + ? APInt::getAllOnes(DstScalarSizeInBits) + : APInt::getOneBitSet(DstScalarSizeInBits, 0); + return APInt::getZero(DstScalarSizeInBits); + }; + + auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional { auto RHSCst = getIConstantVRegVal(RHS, MRI); - if (!LHSCst || !RHSCst) + if (!RHSCst) + return std::nullopt; + auto LHSCst = getIConstantVRegVal(LHS, MRI); + if (!LHSCst) return std::nullopt; switch (Pred) { case CmpInst::Predicate::ICMP_EQ: - return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst)); + return GetICmpResultCst(LHSCst->eq(*RHSCst)); case CmpInst::Predicate::ICMP_NE: - return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst)); + return GetICmpResultCst(LHSCst->ne(*RHSCst)); case CmpInst::Predicate::ICMP_UGT: - return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst)); + return GetICmpResultCst(LHSCst->ugt(*RHSCst)); case CmpInst::Predicate::ICMP_UGE: - return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst)); + return GetICmpResultCst(LHSCst->uge(*RHSCst)); case CmpInst::Predicate::ICMP_ULT: - return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst)); + return GetICmpResultCst(LHSCst->ult(*RHSCst)); case CmpInst::Predicate::ICMP_ULE: - return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst)); + return GetICmpResultCst(LHSCst->ule(*RHSCst)); case CmpInst::Predicate::ICMP_SGT: - return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst)); + return GetICmpResultCst(LHSCst->sgt(*RHSCst)); case CmpInst::Predicate::ICMP_SGE: - return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst)); + return GetICmpResultCst(LHSCst->sge(*RHSCst)); case CmpInst::Predicate::ICMP_SLT: - return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst)); + return GetICmpResultCst(LHSCst->slt(*RHSCst)); case CmpInst::Predicate::ICMP_SLE: - return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst)); + return GetICmpResultCst(LHSCst->sle(*RHSCst)); default: return std::nullopt; } diff --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp index cd6e32311a9ee..7c29c9d419c08 100644 --- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp @@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) { EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue()); } + { + auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One); + EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT); + EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1); + } + + { + auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two); + EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT); + EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0); + } + LLT VecTy = LLT::fixed_vector(2, s32); LLT DstTy = LLT::fixed_vector(2, s1); auto Three = CSEB.buildConstant(s32, 3); @@ -508,6 +520,8 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) { auto OneTwo = CSEB.buildBuildVector(VecTy, {One.getReg(0), Two.getReg(0)}); auto TwoThree = CSEB.buildBuildVector(VecTy, {Two.getReg(0), Three.getReg(0)}); + auto OneThree = + CSEB.buildBuildVector(VecTy, {One.getReg(0), Three.getReg(0)}); auto MinusOneOne = CSEB.buildBuildVector(VecTy, {MinusOne.getReg(0), MinusOne.getReg(0)}); auto MinusOneTwo = @@ -547,6 +561,36 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) { // ICMP_SLE CSEB.buildICmp(CmpInst::Predicate::ICMP_SLE, DstTy, MinusOneTwo, MinusOneOne); + { + auto I = + CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneOne, TwoThree); + EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR); + const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI); + const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI); + EXPECT_EQ(HiCst.getSExtValue(), 0); + EXPECT_EQ(LoCst.getSExtValue(), 0); + } + + { + auto I = + CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneThree, TwoThree); + EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR); + const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI); + const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI); + EXPECT_EQ(HiCst.getSExtValue(), 0); + EXPECT_EQ(LoCst.getSExtValue(), -1); + } + + { + auto I = + CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, TwoThree, TwoThree); + EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR); + const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI); + const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI); + EXPECT_EQ(HiCst.getSExtValue(), -1); + EXPECT_EQ(LoCst.getSExtValue(), -1); + } + auto CheckStr = R"( ; CHECK: [[One:%[0-9]+]]:_(s32) = G_CONSTANT i32 1 ; CHECK: [[Two:%[0-9]+]]:_(s32) = G_CONSTANT i32 2 @@ -558,6 +602,7 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) { ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[One]]:_(s32) ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Two]]:_(s32) ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[Two]]:_(s32), [[Three]]:_(s32) + ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Three]]:_(s32) ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusOne]]:_(s32) ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusTwo]]:_(s32) ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusTwo]]:_(s32), [[MinusThree]]:_(s32)