Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions llvm/lib/Transforms/Scalar/GVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -2206,6 +2207,23 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
std::swap(CmpLHS, CmpRHS);
}

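// Optimize AMDGPU ballot pattern: assume(ballot(cmp) == -1) implies cmp is
// true on all active lanes, and hence cmp can be replaced with true.
// NOTE: assume(ballot(cmp) == exec_mask) with a non-constant mask is not
// handled here: the guard below requires CmpRHS to be a Constant, so the
// !isa<Constant>(CmpRHS) alternative in the inner check can never be true.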
if (isa<IntrinsicInst>(CmpLHS) && isa<Constant>(CmpRHS)) {
auto *IntrCall = cast<IntrinsicInst>(CmpLHS);
// Check if CmpLHS is a ballot intrinsic
if (IntrCall->getIntrinsicID() ==
Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
Value *BallotArg = IntrCall->getArgOperand(0);
if (BallotArg->getType()->isIntegerTy(1) &&
(match(CmpRHS, m_AllOnes()) || !isa<Constant>(CmpRHS))) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second part here, !isa<Constant>(CmpRHS) ... is this meant to be execmask? What if it is simply a value that the programmer knows will match the current state of cmp across active threads? For example:

unsigned threadmask = foo(); // where a bit is 1 for whatever reason
if (divergent_condition) {
  cmp = bar();
  assume(ballot(cmp) == threadmask);
}

CmpLHS = BallotArg;
CmpRHS = ConstantInt::getTrue(BallotArg->getType());
}
}
}

// Handle degenerate case where we either haven't pruned a dead path or
// removed a trivial assume yet.
if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS))
Expand Down
29 changes: 29 additions & 0 deletions llvm/test/Transforms/GVN/assume-equal.ll
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,35 @@ define i8 @assume_ptr_eq_same_prov(ptr %p, i64 %x) {
ret i8 %v
}

; Test AMDGPU ballot pattern optimization
; assume(ballot(cmp) == -1) means cmp is true on all active lanes
; so uses of cmp can be replaced with true.
; The expected output keeps the icmp, the ballot call and the assume
; unchanged and only rewrites the branch condition from %cmp to true.
define void @assume_ballot_uniform(i32 %x) {
; CHECK-LABEL: @assume_ballot_uniform(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
; CHECK: foo:
; CHECK-NEXT: ret void
; CHECK: bar:
; CHECK-NEXT: ret void
;
; %cmp is both the operand of the ballot call and the branch condition below.
%cmp = icmp eq i32 %x, 0
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
; Assume that the ballot result equals -1 (all bits of the i64 set).
%all = icmp eq i64 %ballot, -1
call void @llvm.assume(i1 %all)
; This use of %cmp is the one expected to be replaced with the constant true.
br i1 %cmp, label %foo, label %bar

foo:
ret void

bar:
ret void
}

declare i64 @llvm.amdgcn.ballot.i64(i1)
declare noalias ptr @_Znwm(i64)
declare void @_ZN1AC1Ev(ptr)
declare void @llvm.assume(i1)
Expand Down