Skip to content

Commit 08355f9

Browse files
Address feedback on the location of the opt
- Remove redundant const propagation (assume equality opt) from InstCombine. - Moved assume(ballot(cmp) == -1) optimization from InstCombine to GVN.
1 parent 40eeaff commit 08355f9

File tree

6 files changed

+54
-171
lines changed

6 files changed

+54
-171
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,12 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13221322
if (isa<PoisonValue>(Arg))
13231323
return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
13241324

1325-
// For Wave32 targets, convert i64 ballot to i32 ballot + zext
1325+
if (auto *Src = dyn_cast<ConstantInt>(Arg)) {
1326+
if (Src->isZero()) {
1327+
// amdgcn.ballot(i1 0) is zero.
1328+
return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1329+
}
1330+
}
13261331
if (ST->isWave32() && II.getType()->getIntegerBitWidth() == 64) {
13271332
// %b64 = call i64 ballot.i64(...)
13281333
// =>
@@ -1336,15 +1341,6 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13361341
Call->takeName(&II);
13371342
return IC.replaceInstUsesWith(II, Call);
13381343
}
1339-
1340-
if (auto *Src = dyn_cast<ConstantInt>(Arg)) {
1341-
if (Src->isZero()) {
1342-
// amdgcn.ballot(i1 0) is zero.
1343-
return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1344-
}
1345-
// Note: ballot(true) is NOT constant folded because the result depends
1346-
// on the active lanes in the wavefront, not just the condition value.
1347-
}
13481344
break;
13491345
}
13501346
case Intrinsic::amdgcn_wavefrontsize: {

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3540,81 +3540,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35403540
}
35413541
}
35423542

3543-
// Basic assume equality optimization: assume(x == c) -> replace uses of x
3544-
// with c
3545-
if (auto *ICmp = dyn_cast<ICmpInst>(IIOperand)) {
3546-
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
3547-
Value *LHS = ICmp->getOperand(0);
3548-
Value *RHS = ICmp->getOperand(1);
3549-
Value *Variable = nullptr;
3550-
Constant *ConstantVal = nullptr;
3551-
3552-
if (auto *C = dyn_cast<Constant>(RHS)) {
3553-
Variable = LHS;
3554-
ConstantVal = C;
3555-
} else if (auto *C = dyn_cast<Constant>(LHS)) {
3556-
Variable = RHS;
3557-
ConstantVal = C;
3558-
}
3559-
3560-
if (Variable && ConstantVal && Variable->hasUseList()) {
3561-
SmallVector<Use *, 8> Uses;
3562-
for (Use &U : Variable->uses()) {
3563-
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3564-
if (UseInst != ICmp &&
3565-
isValidAssumeForContext(II, UseInst, &DT)) {
3566-
Uses.push_back(&U);
3567-
}
3568-
}
3569-
}
3570-
3571-
for (Use *U : Uses) {
3572-
U->set(ConstantVal);
3573-
Worklist.pushValue(U->getUser());
3574-
}
3575-
3576-
if (!Uses.empty()) {
3577-
Worklist.pushValue(Variable);
3578-
}
3579-
}
3580-
}
3581-
}
3582-
3583-
// Optimize AMDGPU ballot patterns in assumes:
3584-
// assume(ballot(cmp) == -1) means cmp is true on all active lanes
3585-
// We can replace uses of cmp with true
3586-
Value *BallotInst;
3587-
if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst),
3588-
m_AllOnes()))) {
3589-
if (auto *IntrCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3590-
if (IntrCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
3591-
Value *BallotArg = IntrCall->getArgOperand(0);
3592-
if (BallotArg->getType()->isIntegerTy(1) && BallotArg->hasUseList()) {
3593-
// Find uses and replace with true
3594-
SmallVector<Use *, 8> Uses;
3595-
for (Use &U : BallotArg->uses()) {
3596-
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3597-
if (UseInst != IntrCall &&
3598-
isValidAssumeForContext(II, UseInst, &DT)) {
3599-
Uses.push_back(&U);
3600-
}
3601-
}
3602-
}
3603-
3604-
// Replace uses with true
3605-
for (Use *U : Uses) {
3606-
U->set(ConstantInt::getTrue(BallotArg->getType()));
3607-
Worklist.pushValue(U->getUser());
3608-
}
3609-
3610-
if (!Uses.empty()) {
3611-
Worklist.pushValue(BallotArg);
3612-
}
3613-
}
3614-
}
3615-
}
3616-
}
3617-
36183543
// If there is a dominating assume with the same condition as this one,
36193544
// then this one is redundant, and should be removed.
36203545
KnownBits Known(1);

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "llvm/IR/Instruction.h"
5555
#include "llvm/IR/Instructions.h"
5656
#include "llvm/IR/IntrinsicInst.h"
57+
#include "llvm/IR/IntrinsicsAMDGPU.h"
5758
#include "llvm/IR/LLVMContext.h"
5859
#include "llvm/IR/Metadata.h"
5960
#include "llvm/IR/Module.h"
@@ -2206,6 +2207,23 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
22062207
std::swap(CmpLHS, CmpRHS);
22072208
}
22082209

2210+
// Optimize AMDGPU ballot pattern: assume(ballot(cmp) == -1) or
2211+
// assume(ballot(cmp) == exec_mask). This implies cmp is true on all
2212+
// active lanes and hence can be replaced with true.
2213+
if (isa<IntrinsicInst>(CmpLHS) && isa<Constant>(CmpRHS)) {
2214+
auto *IntrCall = cast<IntrinsicInst>(CmpLHS);
2215+
// Check if CmpLHS is a ballot intrinsic
2216+
if (IntrCall->getIntrinsicID() ==
2217+
Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2218+
Value *BallotArg = IntrCall->getArgOperand(0);
2219+
if (BallotArg->getType()->isIntegerTy(1) &&
2220+
(match(CmpRHS, m_AllOnes()) || !isa<Constant>(CmpRHS))) {
2221+
CmpLHS = BallotArg;
2222+
CmpRHS = ConstantInt::getTrue(BallotArg->getType());
2223+
}
2224+
}
2225+
}
2226+
22092227
// Handle degenerate case where we either haven't pruned a dead path or
22102228
// removed a trivial assume yet.
22112229
if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS))

llvm/test/Transforms/GVN/assume-equal.ll

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,35 @@ define i8 @assume_ptr_eq_same_prov(ptr %p, i64 %x) {
387387
ret i8 %v
388388
}
389389

390+
; Test AMDGPU ballot pattern optimization
391+
; assume(ballot(cmp) == -1) means cmp is true on all active lanes
392+
; so uses of cmp can be replaced with true
393+
define void @assume_ballot_uniform(i32 %x) {
394+
; CHECK-LABEL: @assume_ballot_uniform(
395+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
396+
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
397+
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
398+
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
399+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
400+
; CHECK: foo:
401+
; CHECK-NEXT: ret void
402+
; CHECK: bar:
403+
; CHECK-NEXT: ret void
404+
;
405+
%cmp = icmp eq i32 %x, 0
406+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
407+
%all = icmp eq i64 %ballot, -1
408+
call void @llvm.assume(i1 %all)
409+
br i1 %cmp, label %foo, label %bar
410+
411+
foo:
412+
ret void
413+
414+
bar:
415+
ret void
416+
}
417+
418+
declare i64 @llvm.amdgcn.ballot.i64(i1)
390419
declare noalias ptr @_Znwm(i64)
391420
declare void @_ZN1AC1Ev(ptr)
392421
declare void @llvm.assume(i1)

llvm/test/Transforms/InstCombine/amdgpu-ballot-constant-fold.ll

Lines changed: 0 additions & 56 deletions
This file was deleted.

llvm/test/Transforms/InstCombine/assume.ll

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ define i32 @simple(i32 %a) #1 {
8282
; CHECK-LABEL: @simple(
8383
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], 4
8484
; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]])
85-
; CHECK-NEXT: ret i32 4
85+
; CHECK-NEXT: ret i32 [[A]]
8686
;
8787
%cmp = icmp eq i32 %a, 4
8888
tail call void @llvm.assume(i1 %cmp)
@@ -1034,35 +1034,6 @@ define i1 @neg_assume_trunc_eq_one(i8 %x) {
10341034
ret i1 %q
10351035
}
10361036

1037-
; Test AMDGPU ballot pattern optimization
1038-
; assume(ballot(cmp) == -1) means cmp is true on all active lanes
1039-
; so uses of cmp can be replaced with true
1040-
define void @assume_ballot_uniform(i32 %x) {
1041-
; CHECK-LABEL: @assume_ballot_uniform(
1042-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
1043-
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
1044-
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
1045-
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
1046-
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
1047-
; CHECK: foo:
1048-
; CHECK-NEXT: ret void
1049-
; CHECK: bar:
1050-
; CHECK-NEXT: ret void
1051-
;
1052-
%cmp = icmp eq i32 %x, 0
1053-
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
1054-
%all = icmp eq i64 %ballot, -1
1055-
call void @llvm.assume(i1 %all)
1056-
br i1 %cmp, label %foo, label %bar
1057-
1058-
foo:
1059-
ret void
1060-
1061-
bar:
1062-
ret void
1063-
}
1064-
1065-
declare i64 @llvm.amdgcn.ballot.i64(i1)
10661037
declare void @use(i1)
10671038
declare void @llvm.dbg.value(metadata, metadata, metadata)
10681039

0 commit comments

Comments
 (0)