diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 010b77548c152..2b353d5fd69d6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2761,6 +2761,75 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
   return nullptr;
 }
 
+// Fold a select that guards a shift whose amount has been redundantly clamped
+// (via umin or a pow2 mask) into the same select over the unclamped shift:
+// the guard already guarantees the amount is in range, so the clamp is dead.
+static Instruction *foldSelectWithClampedShift(SelectInst &SI,
+                                               InstCombinerImpl &IC,
+                                               IRBuilderBase &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Type *SelType = SI.getType();
+  uint64_t BW = SelType->getScalarSizeInBits();
+
+  auto MatchClampedShift = [&](Value *V, Value *Amt) -> BinaryOperator * {
+    Value *X, *Limit;
+    Instruction *I;
+
+    // Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+    //  --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+    // Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+    //  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+    // iff C >= BW-1
+    if (match(V, m_OneUse(m_Shift(m_Value(X),
+                                  m_UMin(m_Specific(Amt), m_Value(Limit)))))) {
+      KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
+      if (KnownLimit.getMinValue().uge(BW - 1))
+        return cast<BinaryOperator>(V);
+    }
+
+    // Fold (select (icmp_ugt A, BW-1), (shift X, (and A, C)), FalseVal)
+    //  --> (select (icmp_ugt A, BW-1), (shift X, A), FalseVal)
+    // Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+    //  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+    // iff Pow2 element width we just demand the amt mask bits.
+    if (isPowerOf2_64(BW) &&
+        match(V, m_OneUse(m_Shift(m_Value(X), m_Instruction(I))))) {
+      KnownBits Known(BW);
+      APInt DemandedBits = APInt::getLowBitsSet(BW, Log2_64(BW));
+      if (Value *NewAmt = IC.SimplifyMultipleUseDemandedBits(
+              I, DemandedBits, Known, /*Depth=*/0,
+              IC.getSimplifyQuery().getWithInstruction(I)))
+        return Amt == NewAmt ? cast<BinaryOperator>(V) : nullptr;
+    }
+
+    return nullptr;
+  };
+
+  Value *Amt;
+  if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_UGT, m_Value(Amt),
+                                    m_SpecificInt(BW - 1)))) {
+    if (BinaryOperator *ShiftI = MatchClampedShift(FalseVal, Amt)) {
+      // Freeze the amount: it is now used both by the compare and the shift.
+      Amt = Builder.CreateFreeze(Amt);
+      return SelectInst::Create(
+          Builder.CreateICmpUGT(Amt, cast<ICmpInst>(CondVal)->getOperand(1)),
+          TrueVal,
+          Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt));
+    }
+  }
+
+  if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(Amt),
+                                    m_SpecificInt(BW)))) {
+    if (BinaryOperator *ShiftI = MatchClampedShift(TrueVal, Amt)) {
+      // Freeze the amount: it is now used both by the compare and the shift.
+      Amt = Builder.CreateFreeze(Amt);
+      return SelectInst::Create(
+          Builder.CreateICmpULT(Amt, cast<ICmpInst>(CondVal)->getOperand(1)),
+          Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt),
+          FalseVal);
+    }
+  }
+
+  return nullptr;
+}
+
 static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
   FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
   if (!FI)
@@ -3871,6 +3940,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  if (Instruction *I = foldSelectWithClampedShift(SI, *this, Builder))
+    return I;
+
   if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
     return I;
 
diff --git a/llvm/test/Transforms/InstCombine/select-shift-clamp.ll b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
new file mode 100644
index 0000000000000..31c60b58ed65e
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
@@ -0,0 +1,236 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare void @use_i17(i17)
+declare void @use_i32(i32)
+
+; Fold (select (icmp_ugt A, BW-1), (shift X, (and A, C)), FalseVal)
+; --> (select (icmp_ugt A, BW-1), (shift X, A), FalseVal)
+; Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff Pow2 element width and C masks all amt bits.
+
+define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1]], 32
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = and i32 %a1, 31
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1]], 32
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ule i32 %a1, 31
+  %m = and i32 %a1, 127
+  %s = ashr i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_lshr_clamp_and_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1]], 31
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 31
+  %m = and i32 %a1, 31
+  %s = lshr i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+define i32 @select_uge_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_uge_shl_clamp_and_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1]], 31
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp uge i32 %a1, 32
+  %m = and i32 %a1, 63
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+; negative test - multiuse
+define i32 @select_ule_ashr_clamp_and_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32_multiuse(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = and i32 [[A1]], 127
+; CHECK-NEXT:    [[S:%.*]] = ashr i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT:    call void @use_i32(i32 [[S]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ule i32 %a1, 31
+  %m = and i32 %a1, 127
+  %s = ashr i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  call void @use_i32(i32 %s)
+  ret i32 %r
+}
+
+; negative test - mask doesn't cover all legal amount bits
+define i32 @select_ult_shl_clamp_and_i32_badmask(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32_badmask(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = and i32 [[A1]], 28
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = and i32 %a1, 28
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+; negative test - non-pow2
+define i17 @select_uge_lshr_clamp_and_i17_nonpow2(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_and_i17_nonpow2(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 16
+; CHECK-NEXT:    [[M:%.*]] = and i17 [[A1]], 255
+; CHECK-NEXT:    [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp uge i17 %a1, 17
+  %m = and i17 %a1, 255
+  %s = lshr i17 %a0, %m
+  %r = select i1 %c, i17 %a2, i17 %s
+  ret i17 %r
+}
+
+; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+; --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+; Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff C >= BW-1
+
+define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_umin_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1]], 32
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 31)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_umin_i17(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i17 [[A3:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i17 [[A1]], 17
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[TMP1]], i17 [[A2:%.*]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp ule i17 %a1, 16
+  %m = call i17 @llvm.umin.i17(i17 %a1, i17 17)
+  %s = ashr i17 %a0, %m
+  %r = select i1 %c, i17 %s, i17 %a2
+  ret i17 %r
+}
+
+define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze i32 [[A3:%.*]]
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1]], 31
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 31
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+define <2 x i32> @select_uge_lshr_clamp_umin_v2i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32> %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_v2i32(
+; CHECK-NEXT:    [[A1:%.*]] = freeze <2 x i32> [[A3:%.*]]
+; CHECK-NEXT:    [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt <2 x i32> [[A1]], splat (i32 31)
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> [[A2:%.*]], <2 x i32> [[S]]
+; CHECK-NEXT:    ret <2 x i32> [[R]]
+;
+  %c = icmp uge <2 x i32> %a1, <i32 32, i32 32>
+  %m = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %a1, <2 x i32> <i32 31, i32 31>)
+  %s = lshr <2 x i32> %a0, %m
+  %r = select <2 x i1> %c, <2 x i32> %a2, <2 x i32> %s
+  ret <2 x i32> %r
+}
+
+; negative test - multiuse
+define i32 @select_ugt_shl_clamp_umin_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32_multiuse(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 128)
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT:    call void @use_i32(i32 [[S]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 32
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  call void @use_i32(i32 %s)
+  ret i32 %r
+}
+
+; negative test - umin limit doesn't cover all legal amounts
+define i17 @select_uge_lshr_clamp_umin_i17_badlimit(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_i17_badlimit(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 15
+; CHECK-NEXT:    [[M:%.*]] = call i17 @llvm.umin.i17(i17 [[A1]], i17 12)
+; CHECK-NEXT:    [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp uge i17 %a1, 16
+  %m = call i17 @llvm.umin.i17(i17 %a1, i17 12)
+  %s = lshr i17 %a0, %m
+  %r = select i1 %c, i17 %a2, i17 %s
+  ret i17 %r
+}
+
+define range(i64 0, -9223372036854775807) <4 x i64> @PR109888(<4 x i64> %0) {
+; CHECK-LABEL: @PR109888(
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i64> [[TMP1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw <4 x i64> splat (i64 1), [[TMP0]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ult <4 x i64> [[TMP0]], splat (i64 64)
+; CHECK-NEXT:    [[R:%.*]] = select <4 x i1> [[C]], <4 x i64> [[TMP2]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    ret <4 x i64> [[R]]
+;
+  %c = icmp ult <4 x i64> %0, <i64 64, i64 64, i64 64, i64 64>
+  %m = and <4 x i64> %0, <i64 63, i64 63, i64 63, i64 63>
+  %s = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, %m
+  %r = select <4 x i1> %c, <4 x i64> %s, <4 x i64> zeroinitializer
+  ret <4 x i64> %r
+}