@@ -1027,39 +1027,45 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
10271027
10281028std::optional<SmallVector<APInt>>
10291029llvm::ConstantFoldICmp (unsigned Pred, const Register Op1, const Register Op2,
1030+ unsigned DstScalarSizeInBits, unsigned ExtOp,
10301031 const MachineRegisterInfo &MRI) {
1031- LLT Ty = MRI.getType (Op1);
1032- if (Ty != MRI.getType (Op2))
1033- return std::nullopt ;
1032+ assert (ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
1033+ ExtOp == TargetOpcode::G_ANYEXT);
1034+
1035+ const LLT Ty = MRI.getType (Op1);
10341036
1035- auto TryFoldScalar = [&MRI, Pred](Register LHS,
1036- Register RHS) -> std::optional<APInt> {
1037+ auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
10371038 auto LHSCst = getIConstantVRegVal (LHS, MRI);
10381039 auto RHSCst = getIConstantVRegVal (RHS, MRI);
10391040 if (!LHSCst || !RHSCst)
10401041 return std::nullopt ;
10411042
1043+ const APInt FalseCst = APInt::getZero (DstScalarSizeInBits);
1044+ const APInt TrueCst = (ExtOp == TargetOpcode::G_SEXT)
1045+ ? APInt::getAllOnes (DstScalarSizeInBits)
1046+ : APInt::getOneBitSet (DstScalarSizeInBits, 0 );
1047+
10421048 switch (Pred) {
10431049 case CmpInst::Predicate::ICMP_EQ:
1044- return APInt ( /* numBits= */ 1 , LHSCst->eq (*RHSCst)) ;
1050+ return LHSCst->eq (*RHSCst) ? TrueCst : FalseCst ;
10451051 case CmpInst::Predicate::ICMP_NE:
1046- return APInt ( /* numBits= */ 1 , LHSCst->ne (*RHSCst)) ;
1052+ return LHSCst->ne (*RHSCst) ? TrueCst : FalseCst ;
10471053 case CmpInst::Predicate::ICMP_UGT:
1048- return APInt ( /* numBits= */ 1 , LHSCst->ugt (*RHSCst)) ;
1054+ return LHSCst->ugt (*RHSCst) ? TrueCst : FalseCst ;
10491055 case CmpInst::Predicate::ICMP_UGE:
1050- return APInt ( /* numBits= */ 1 , LHSCst->uge (*RHSCst)) ;
1056+ return LHSCst->uge (*RHSCst) ? TrueCst : FalseCst ;
10511057 case CmpInst::Predicate::ICMP_ULT:
1052- return APInt ( /* numBits= */ 1 , LHSCst->ult (*RHSCst)) ;
1058+ return LHSCst->ult (*RHSCst) ? TrueCst : FalseCst ;
10531059 case CmpInst::Predicate::ICMP_ULE:
1054- return APInt ( /* numBits= */ 1 , LHSCst->ule (*RHSCst)) ;
1060+ return LHSCst->ule (*RHSCst) ? TrueCst : FalseCst ;
10551061 case CmpInst::Predicate::ICMP_SGT:
1056- return APInt ( /* numBits= */ 1 , LHSCst->sgt (*RHSCst)) ;
1062+ return LHSCst->sgt (*RHSCst) ? TrueCst : FalseCst ;
10571063 case CmpInst::Predicate::ICMP_SGE:
1058- return APInt ( /* numBits= */ 1 , LHSCst->sge (*RHSCst)) ;
1064+ return LHSCst->sge (*RHSCst) ? TrueCst : FalseCst ;
10591065 case CmpInst::Predicate::ICMP_SLT:
1060- return APInt ( /* numBits= */ 1 , LHSCst->slt (*RHSCst)) ;
1066+ return LHSCst->slt (*RHSCst) ? TrueCst : FalseCst ;
10611067 case CmpInst::Predicate::ICMP_SLE:
1062- return APInt ( /* numBits= */ 1 , LHSCst->sle (*RHSCst)) ;
1068+ return LHSCst->sle (*RHSCst) ? TrueCst : FalseCst ;
10631069 default :
10641070 return std::nullopt ;
10651071 }
0 commit comments