diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index b9b5b5823d780..39727b1613653 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -54,6 +54,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
@@ -2206,6 +2207,24 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
           std::swap(CmpLHS, CmpRHS);
       }
 
+      // Optimize the AMDGPU ballot pattern assume(ballot(cmp) == -1). Every
+      // lane's bit is set, including the lane executing the assume, so cmp
+      // is known to be true here and can be replaced with true.
+      // A non-constant RHS is deliberately not accepted: nothing proves that
+      // an arbitrary value equals the exec mask (it may be 0 at run time), so
+      // treating it as one would replace a false cmp with true.
+      if (auto *IntrCall = dyn_cast<IntrinsicInst>(CmpLHS)) {
+        // Check if CmpLHS is a ballot intrinsic compared against all-ones.
+        if (IntrCall->getIntrinsicID() == Intrinsic::amdgcn_ballot &&
+            match(CmpRHS, m_AllOnes())) {
+          Value *BallotArg = IntrCall->getArgOperand(0);
+          if (BallotArg->getType()->isIntegerTy(1)) {
+            CmpLHS = BallotArg;
+            CmpRHS = ConstantInt::getTrue(BallotArg->getType());
+          }
+        }
+      }
+
       // Handle degenerate case where we either haven't pruned a dead path or a
       // removed a trivial assume yet.
       if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS))
diff --git a/llvm/test/Transforms/GVN/assume-equal.ll b/llvm/test/Transforms/GVN/assume-equal.ll
index 0c922daf82b32..3eb10f4ab99e7 100644
--- a/llvm/test/Transforms/GVN/assume-equal.ll
+++ b/llvm/test/Transforms/GVN/assume-equal.ll
@@ -387,6 +387,62 @@ define i8 @assume_ptr_eq_same_prov(ptr %p, i64 %x) {
   ret i8 %v
 }
 
+; Test AMDGPU ballot pattern optimization
+; assume(ballot(cmp) == -1) means cmp is true on all active lanes
+; so uses of cmp can be replaced with true
+define void @assume_ballot_uniform(i32 %x) {
+; CHECK-LABEL: @assume_ballot_uniform(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
+; CHECK-NEXT:    [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[ALL]])
+; CHECK-NEXT:    br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
+  %all = icmp eq i64 %ballot, -1
+  call void @llvm.assume(i1 %all)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+; Negative test: a non-constant value compared against the ballot is not
+; known to be the exec mask, so cmp must not be replaced.
+define void @assume_ballot_nonconstant_rhs(i32 %x, i64 %mask) {
+; CHECK-LABEL: @assume_ballot_nonconstant_rhs(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
+; CHECK-NEXT:    [[ALL:%.*]] = icmp eq i64 [[BALLOT]], [[MASK:%.*]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[ALL]])
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
+  %all = icmp eq i64 %ballot, %mask
+  call void @llvm.assume(i1 %all)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+declare i64 @llvm.amdgcn.ballot.i64(i1)
 declare noalias ptr @_Znwm(i64)
 declare void @_ZN1AC1Ev(ptr)
 declare void @llvm.assume(i1)