Skip to content
Merged
77 changes: 77 additions & 0 deletions llvm/lib/Analysis/InlineCost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomConditionCache.h"
#include "llvm/Analysis/EphemeralValuesCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
Expand Down Expand Up @@ -262,6 +263,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
// Cache the DataLayout since we use it a lot.
const DataLayout &DL;

DominatorTree DT;

/// The OptimizationRemarkEmitter available for this compilation.
OptimizationRemarkEmitter *ORE;

Expand Down Expand Up @@ -444,6 +447,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
bool canFoldInboundsGEP(GetElementPtrInst &I);
bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
bool simplifyCallSite(Function *F, CallBase &Call);
bool simplifyCmpInst(Function *F, CmpInst &Cmp);
bool simplifyInstruction(Instruction &I);
bool simplifyIntrinsicCallIsConstant(CallBase &CB);
bool simplifyIntrinsicCallObjectSize(CallBase &CB);
Expand Down Expand Up @@ -1676,6 +1680,75 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
return isGEPFree(I);
}

// Simplify \p Cmp when its LHS is a function argument, its RHS is a constant,
// and ValueTracking can evaluate it for the argument of a recursive call.
// This handles a Cmp that guards a recursive call whose argument makes the
// same Cmp fold to a known result inside that recursive call.
bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
// Bail out if LHS is not a function argument or RHS is NOT const:
if (!isa<Argument>(Cmp.getOperand(0)) || !isa<Constant>(Cmp.getOperand(1)))
return false;
auto *CmpOp = Cmp.getOperand(0);
// Iterate over the users of the function to check if it's a recursive
// function:
for (auto *U : F->users()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this iterating over the uses of the function? Shouldn't this be inspecting just the CandidateCall in particular?

It looks like this checks if there is any recursive call of the right form, even if it's not the call-site being analyzed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to get the recursive call that is guarded by the cmp instr that I'm analyzing.

I can do that in one of two ways. The top-down approach finds the branch that uses the icmp, takes its successors, and iterates over all the instructions in those successors to find the recursive call. The bottom-up approach finds any recursive call of the function and walks back through its predecessors until it reaches the icmp.
So, I think the bottom-up approach is better here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern here is that there may be multiple calls of the function, only one of which is the recursive call you are interested in. We will separately compute the cost for each of these call-sites -- but we will treat each of them as if they were the recursive call, and thus incorrectly assign them a lower cost. Instead, we should only do this when trying to inline the actual recursive call, as given by CandidateCall.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I got your point. Yes, I agree with you.
I think there should be a patch for ValueTracking changes, and another following patch for the 2 points you mentioned about the Inliner.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Today, I will create a patch for the ValueTracking, and another patch for resolving your comments on InlineCost.

CallInst *Call = dyn_cast<CallInst>(U);
if (!Call || Call->getFunction() != F || Call->getCalledFunction() != F)
continue;
auto *CallBB = Call->getParent();
auto *Predecessor = CallBB->getSinglePredecessor();
// Only handle the case when the callsite has a single predecessor:
if (!Predecessor)
continue;

auto *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
if (!Br || Br->isUnconditional())
continue;
// Check if the Br condition is the same Cmp instr we are investigating:
if (Br->getCondition() != &Cmp)
continue;
// Check whether the recursive call site passes a different value for the
// argument that the cmp instruction tests:
bool ArgFound = false;
Value *FuncArg = nullptr, *CallArg = nullptr;
for (unsigned ArgNum = 0;
ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum++) {
FuncArg = F->getArg(ArgNum);
CallArg = Call->getArgOperand(ArgNum);
if (FuncArg == CmpOp && CallArg != CmpOp) {
ArgFound = true;
break;
}
}
if (!ArgFound)
continue;
// Now we have a recursive call that is guarded by a cmp instruction.
// Check if this cmp can be simplified:
SimplifyQuery SQ(DL, dyn_cast<Instruction>(CallArg));
DomConditionCache DC;
DC.registerBranch(Br);
SQ.DC = &DC;
if (DT.root_size() == 0) {
// Dominator tree was never constructed for any function yet.
DT.recalculate(*F);
} else if (DT.getRoot()->getParent() != F) {
// Dominator tree was constructed for a different function, recalculate
// it for the current function.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fishy. Isn't the CallAnalyzer instantiated per call-site? How can we end up with different functions here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. The else-if statement should be removed.

DT.recalculate(*F);
}
SQ.DT = &DT;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to inject the condition via CondContext instead?

Copy link
Member Author

@hassnaaHamdi hassnaaHamdi May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the current logic in ValueTracking.cpp, injecting the condition would not be useful.
But if ValueTracking.cpp's computeKnownFPClassFromContext were changed to check the CondContext and call computeKnownFPClassFromCond directly, similar to the logic in computeKnownBitsFromContext, then injection would be useful.
Maybe I can create a patch for ValueTracking first, and then a follow-up patch applying your suggestion.

Value *SimplifiedInstruction = llvm::simplifyInstructionWithOperands(
cast<CmpInst>(&Cmp), {CallArg, Cmp.getOperand(1)}, SQ);
if (auto *ConstVal = dyn_cast_or_null<ConstantInt>(SimplifiedInstruction)) {
bool IsTrueSuccessor = CallBB == Br->getSuccessor(0);
SimplifiedValues[&Cmp] = ConstVal;
if (ConstVal->isOne())
return !IsTrueSuccessor;
return IsTrueSuccessor;
}
}
return false;
}

/// Simplify \p I if its operands are constants and update SimplifiedValues.
bool CallAnalyzer::simplifyInstruction(Instruction &I) {
SmallVector<Constant *> COps;
Expand Down Expand Up @@ -2060,6 +2133,10 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
if (simplifyInstruction(I))
return true;

// Try to handle comparison that can be simplified using ValueTracking.
if (simplifyCmpInst(I.getFunction(), I))
return true;

if (I.getOpcode() == Instruction::FCmp)
return false;

Expand Down
179 changes: 179 additions & 0 deletions llvm/test/Transforms/Inline/inline-recursive-fn.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='inline,instcombine' < %s | FileCheck %s

; The recursive call sits in the true successor of the `fcmp olt %x, 0.0`
; guard and passes `fneg %x`. The CHECK lines expect the call to be inlined
; once, with the inlined copy of the guard folded to `br i1 false`.
define float @inline_rec_true_successor(float %x, float %scale) {
; CHECK-LABEL: define float @inline_rec_true_successor(
; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_TRUE_SUCCESSOR_EXIT:.*]] ], [ [[MUL:%.*]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: br i1 false, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
; CHECK: [[IF_THEN_I]]:
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
; CHECK: [[IF_END_I]]:
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[X]]
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
; CHECK: [[INLINE_REC_TRUE_SUCCESSOR_EXIT]]:
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ poison, %[[IF_THEN_I]] ], [ [[MUL_I]], %[[IF_END_I]] ]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[MUL]] = fmul float [[X]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp olt float %x, 0.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%fneg = fneg float %x
%call = tail call float @inline_rec_true_successor(float %fneg, float %scale)
br label %common.ret18

if.end: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18
}

; Non-recursive wrapper: forwards its arguments to @inline_rec_true_successor
; and returns the result.
define float @test_inline_rec_true_successor(float %x, float %scale) {
entry:
%res = tail call float @inline_rec_true_successor(float %x, float %scale)
ret float %res
}

; The recursive call sits in the false successor of the `fcmp uge %x, 0.0`
; guard and passes `fneg %x`. The CHECK lines expect the call to be inlined
; once, with the inlined copy of the guard folded to `br i1 true`.
define float @inline_rec_false_successor(float %x, float %scale) {
; CHECK-LABEL: define float @inline_rec_false_successor(
; CHECK-SAME: float [[Y:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp uge float [[Y]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_FALSE_SUCCESSOR_EXIT:.*]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[MUL]] = fmul float [[Y]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: br i1 true, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
; CHECK: [[IF_THEN_I]]:
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[Y]]
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
; CHECK: [[IF_END_I]]:
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
; CHECK: [[INLINE_REC_FALSE_SUCCESSOR_EXIT]]:
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ [[MUL_I]], %[[IF_THEN_I]] ], [ poison, %[[IF_END_I]] ]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp uge float %x, 0.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %mul, %if.then ], [ %call, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18

if.end: ; preds = %entry
%fneg = fneg float %x
%call = tail call float @inline_rec_false_successor(float %fneg, float %scale)
br label %common.ret18
}

; Non-recursive wrapper: forwards its arguments to @inline_rec_false_successor
; and returns the result.
define float @test_inline_rec_false_successor(float %x, float %scale) {
entry:
%res = tail call float @inline_rec_false_successor(float %x, float %scale)
ret float %res
}

; Test the case where the branch condition is a plain i1 argument rather than
; a cmp instruction. The recursive call passes `false` for %flag, and the
; CHECK lines expect it to be inlined (no call remains in if.then).
define float @inline_rec_no_cmp(i1 %flag, float %scale) {
; CHECK-LABEL: define float @inline_rec_no_cmp(
; CHECK-SAME: i1 [[FLAG:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: br i1 [[FLAG]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[SUM:%.*]] = fadd float [[SCALE]], 5.000000e+00
; CHECK-NEXT: [[SUM1:%.*]] = fadd float [[SUM]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET:.*]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[SUM2:%.*]] = fadd float [[SCALE]], 5.000000e+00
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: [[COMMON_RET_RES:%.*]] = phi float [ [[SUM1]], %[[IF_THEN]] ], [ [[SUM2]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET_RES]]
;
entry:
br i1 %flag, label %if.then, label %if.end
if.then:
%res = tail call float @inline_rec_no_cmp(i1 false, float %scale)
%sum1 = fadd float %res, %scale
br label %common.ret
if.end:
%sum2 = fadd float %scale, 5.000000e+00
br label %common.ret
common.ret:
%common.ret.res = phi float [ %sum1, %if.then ], [ %sum2, %if.end ]
ret float %common.ret.res
}

; Non-recursive wrapper: forwards its arguments to @inline_rec_no_cmp and
; returns the result.
define float @test_inline_rec_no_cmp(i1 %flag, float %scale) {
entry:
%res = tail call float @inline_rec_no_cmp(i1 %flag, float %scale)
ret float %res
}

; Negative test: the recursive call is guarded by `fcmp olt %x, 5.0` and
; passes `%x + 5.0`. Knowing %x < 5.0 does not decide whether
; (%x + 5.0) < 5.0, and the CHECK lines expect the recursive call to remain.
define float @no_inline_rec(float %x, float %scale) {
; CHECK-LABEL: define float @no_inline_rec(
; CHECK-SAME: float [[Z:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[Z]], 5.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[FNEG1:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[FADD:%.*]] = fadd float [[Z]], 5.000000e+00
; CHECK-NEXT: [[CALL:%.*]] = tail call float @no_inline_rec(float [[FADD]], float [[SCALE]])
; CHECK-NEXT: [[FNEG1]] = fneg float [[CALL]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[MUL]] = fmul float [[Z]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp olt float %x, 5.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %fneg1, %if.then ], [ %mul, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%fadd = fadd float %x, 5.000000e+00
%call = tail call float @no_inline_rec(float %fadd, float %scale)
%fneg1 = fneg float %call
br label %common.ret18

if.end: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18
}

; Non-recursive wrapper: forwards its arguments to @no_inline_rec and returns
; the result.
define float @test_no_inline(float %x, float %scale) {
entry:
%res = tail call float @no_inline_rec(float %x, float %scale)
ret float %res
}
Loading