From e5e18d4010e76fdeb709a9c5c98d877ef1fd57a3 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 17 Oct 2024 11:10:02 -0400 Subject: [PATCH 1/4] [OPT] Search whole BB for convergence token. The spec for llvm.experimental.convergence.entry says that it must be in the entry block for a function, and must precede any other convergent operation. It does not have to be the first instruction in the entry block. Inlining assumes that the call to llvm.experimental.convergence.entry will be the first instruction after any phi instructions. This commit modifies inlining to search the entire block for the call. --- llvm/lib/Transforms/Utils/InlineFunction.cpp | 37 +++++++++++-------- .../Transforms/Inline/convergence-inline.ll | 24 ++++++++++++ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 13eb588e46de8..1a0b77bbfdb8c 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -180,6 +180,19 @@ namespace { } }; + IntrinsicInst *getConevrgenceEntryIfAny(BasicBlock &BB) { + auto *I = BB.getFirstNonPHI(); + while (I) { + if (auto *IntrinsicCall = dyn_cast(I)) { + if (IntrinsicCall->getIntrinsicID() == + Intrinsic::experimental_convergence_entry) { + return IntrinsicCall; + } + } + I = I->getNextNode(); + } + return nullptr; + } } // end anonymous namespace /// Get or create a target for the branch from ResumeInsts. @@ -2438,15 +2451,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // fully implements convergence control tokens, there is no mixing of // controlled and uncontrolled convergent operations in the whole program. 
if (CB.isConvergent()) { - auto *I = CalledFunc->getEntryBlock().getFirstNonPHI(); - if (auto *IntrinsicCall = dyn_cast(I)) { - if (IntrinsicCall->getIntrinsicID() == - Intrinsic::experimental_convergence_entry) { - if (!ConvergenceControlToken) { - return InlineResult::failure( - "convergent call needs convergencectrl operand"); - } - } + auto *I = getConevrgenceEntryIfAny(CalledFunc->getEntryBlock()); + if (I && !ConvergenceControlToken) { + return InlineResult::failure( + "convergent call needs convergencectrl operand"); } } @@ -2737,13 +2745,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } if (ConvergenceControlToken) { - auto *I = FirstNewBlock->getFirstNonPHI(); - if (auto *IntrinsicCall = dyn_cast(I)) { - if (IntrinsicCall->getIntrinsicID() == - Intrinsic::experimental_convergence_entry) { - IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); - IntrinsicCall->eraseFromParent(); - } + auto *IntrinsicCall = getConevrgenceEntryIfAny(*FirstNewBlock); + if (IntrinsicCall) { + IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); + IntrinsicCall->eraseFromParent(); } } diff --git a/llvm/test/Transforms/Inline/convergence-inline.ll b/llvm/test/Transforms/Inline/convergence-inline.ll index 8c67e6a59b7db..4996a2376be63 100644 --- a/llvm/test/Transforms/Inline/convergence-inline.ll +++ b/llvm/test/Transforms/Inline/convergence-inline.ll @@ -185,6 +185,30 @@ define void @test_two_calls() convergent { ret void } +define i32 @token_not_first(i32 %x) convergent alwaysinline { +; CHECK-LABEL: @token_not_first( +; CHECK-NEXT: {{%.*}} = alloca ptr, align 8 +; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[Y:%.*]] = call i32 @g(i32 [[X:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ] +; CHECK-NEXT: ret i32 [[Y]] +; + %p = alloca ptr, align 8 + %token = call token @llvm.experimental.convergence.entry() + %y = call i32 @g(i32 %x) [ "convergencectrl"(token %token) ] + ret i32 %y +} + 
+define void @test_token_not_first() convergent { +; CHECK-LABEL: @test_token_not_first( +; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: {{%.*}} = call i32 @g(i32 23) [ "convergencectrl"(token [[TOKEN]]) ] +; CHECK-NEXT: ret void +; + %token = call token @llvm.experimental.convergence.entry() + %x = call i32 @token_not_first(i32 23) [ "convergencectrl"(token %token) ] + ret void +} + declare void @f(i32) convergent declare i32 @g(i32) convergent From ff9b6366055c7990165e53be8b6296e5a26cd392 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Mon, 28 Oct 2024 11:29:18 -0400 Subject: [PATCH 2/4] Small fixes based on code review. --- llvm/lib/Transforms/Utils/InlineFunction.cpp | 27 ++++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 1a0b77bbfdb8c..9c7626fd79a17 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -179,21 +179,20 @@ namespace { } } }; +} // end anonymous namespace - IntrinsicInst *getConevrgenceEntryIfAny(BasicBlock &BB) { - auto *I = BB.getFirstNonPHI(); - while (I) { - if (auto *IntrinsicCall = dyn_cast(I)) { - if (IntrinsicCall->getIntrinsicID() == - Intrinsic::experimental_convergence_entry) { - return IntrinsicCall; - } +static IntrinsicInst *getConvergenceEntry(BasicBlock &BB) { + auto *I = BB.getFirstNonPHI(); + while (I) { + if (auto *IntrinsicCall = dyn_cast(I)) { + if (IntrinsicCall->isEntry()) { + return IntrinsicCall; } - I = I->getNextNode(); } - return nullptr; + I = I->getNextNode(); } -} // end anonymous namespace + return nullptr; +} /// Get or create a target for the branch from ResumeInsts. 
BasicBlock *LandingPadInliningInfo::getInnerResumeDest() { @@ -2451,8 +2450,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // fully implements convergence control tokens, there is no mixing of // controlled and uncontrolled convergent operations in the whole program. if (CB.isConvergent()) { - auto *I = getConevrgenceEntryIfAny(CalledFunc->getEntryBlock()); - if (I && !ConvergenceControlToken) { + if (!ConvergenceControlToken && + getConvergenceEntry(CalledFunc->getEntryBlock())) { return InlineResult::failure( "convergent call needs convergencectrl operand"); } @@ -2745,7 +2744,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } if (ConvergenceControlToken) { - auto *IntrinsicCall = getConevrgenceEntryIfAny(*FirstNewBlock); + auto *IntrinsicCall = getConvergenceEntry(*FirstNewBlock); if (IntrinsicCall) { IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); IntrinsicCall->eraseFromParent(); From a2edc287391962b9e9760eb7f2017fd1f3f0f3ce Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Mon, 28 Oct 2024 11:33:42 -0400 Subject: [PATCH 3/4] Fix format --- llvm/lib/Transforms/Utils/InlineFunction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 9c7626fd79a17..9c02fbc9cda7f 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -181,7 +181,7 @@ namespace { }; } // end anonymous namespace -static IntrinsicInst *getConvergenceEntry(BasicBlock &BB) { +static IntrinsicInst *getConvergenceEntry(BasicBlock &BB) { auto *I = BB.getFirstNonPHI(); while (I) { if (auto *IntrinsicCall = dyn_cast(I)) { From 0d8de066fcaa2f265aaf6fab6d43e153f8357922 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 30 Oct 2024 09:35:40 -0400 Subject: [PATCH 4/4] Replace use of 'auto' --- llvm/lib/Transforms/Utils/InlineFunction.cpp | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 7035235374f0d..a27cb4dd219c3 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -2802,7 +2802,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } if (ConvergenceControlToken) { - auto *IntrinsicCall = getConvergenceEntry(*FirstNewBlock); + IntrinsicInst *IntrinsicCall = getConvergenceEntry(*FirstNewBlock); if (IntrinsicCall) { IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); IntrinsicCall->eraseFromParent();