Skip to content

Commit 679f7b0

Browse files
Address feedback:1. Add the condition to make sure ballot width matches wave sixe
2. Modify logic to make the InferredCondValue non-optional 3. Simplified cmp arg match logic 4. Added missing braces and comments indents 5. Removed O2 checks from the test cases
1 parent b6cef74 commit 679f7b0

File tree

3 files changed

+20
-27
lines changed

3 files changed

+20
-27
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,9 +1353,10 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13531353
if (!isa<Instruction>(Arg))
13541354
break;
13551355

1356-
// Skip if ballot width < wave size (e.g., ballot.i32 on wave64).
1357-
if (ST->isWave64() && II.getType()->getIntegerBitWidth() == 32)
1358-
break; // ballot.i32 on wave64 captures only lanes [0:31]
1356+
// Skip if ballot width doesn't match wave size.
1357+
unsigned WavefrontSize = ST->getWavefrontSize();
1358+
if (WavefrontSize != II.getType()->getIntegerBitWidth())
1359+
break;
13591360

13601361
for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
13611362
if (!AssumeVH)
@@ -1364,37 +1365,37 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13641365
auto *Assume = cast<AssumeInst>(AssumeVH);
13651366
Value *Cond = Assume->getArgOperand(0);
13661367

1367-
// Check if assume condition is an equality comparison.
1368-
auto *ICI = dyn_cast<ICmpInst>(Cond);
1368+
// Pattern match: assume(icmp eq ballot, CompareValue)
1369+
ICmpInst *ICI = dyn_cast<ICmpInst>(Cond);
13691370
if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
13701371
continue;
13711372

1372-
// Extract the ballot and the value being compared against it.
1373-
Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
1374-
Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
1375-
if (!CompareValue)
1373+
Value *CompareValue;
1374+
if (!match(ICI, m_c_ICmp(m_Specific(&II), m_Value(CompareValue))))
13761375
continue;
13771376

1378-
// Determine the constant value of the ballot's condition argument.
1379-
std::optional<bool> InferredCondValue;
1377+
// Determine the inferred value of the ballot's condition argument.
1378+
bool InferredCondValue;
13801379
if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1381-
// ballot(x) == -1 means all lanes have x = true.
1382-
if (CI->isMinusOne())
1380+
if (CI->isMinusOne()) {
1381+
// ballot(x) == -1 means all lanes have x = true.
13831382
InferredCondValue = true;
1384-
// ballot(x) == 0 means all lanes have x = false.
1385-
else if (CI->isZero())
1383+
} else if (CI->isZero()) {
1384+
// ballot(x) == 0 means all lanes have x = false.
13861385
InferredCondValue = false;
1386+
} else {
1387+
continue;
1388+
}
13871389
} else if (match(CompareValue,
13881390
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
13891391
// ballot(x) == ballot(true) means x = true (EXEC mask comparison).
13901392
InferredCondValue = true;
1391-
}
1392-
1393-
if (!InferredCondValue)
1393+
} else {
13941394
continue;
1395+
}
13951396

13961397
Constant *ReplacementValue =
1397-
ConstantInt::getBool(Arg->getContext(), *InferredCondValue);
1398+
ConstantInt::getBool(Arg->getContext(), InferredCondValue);
13981399

13991400
// Replace uses of the condition argument dominated by the assume.
14001401
bool Changed = false;

llvm/test/Transforms/InstCombine/AMDGPU/llvm.amdgcn.ballot-assume-wave32.ll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -mattr=+wavefrontsize32 -passes=instcombine -S | FileCheck %s
3-
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -mattr=+wavefrontsize32 -O2 -S | FileCheck %s --check-prefix=O2
43
;
54
; Wave32-specific tests for ballot-assume optimizations.
65
; - ballot.i32 should optimize (captures all 32 lanes)
@@ -27,9 +26,6 @@ define amdgpu_kernel void @wave32_ballot_i32_all_lanes(i32 %x, ptr addrspace(1)
2726
; CHECK: bar:
2827
; CHECK-NEXT: ret void
2928
;
30-
; O2-LABEL: @wave32_ballot_i32_all_lanes(
31-
; O2-NEXT: common.ret:
32-
; O2-NEXT: ret void
3329
;
3430
%cmp = icmp eq i32 %x, 0
3531
%ballot = call i32 @llvm.amdgcn.ballot.i32(i1 %cmp)

llvm/test/Transforms/InstCombine/AMDGPU/llvm.amdgcn.ballot-assume-wave64.ll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -mattr=+wavefrontsize64 -passes=instcombine -S | FileCheck %s
3-
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -mattr=+wavefrontsize64 -O2 -S | FileCheck %s --check-prefix=O2
43
;
54
; Wave64-specific tests for ballot-assume optimizations.
65
; - ballot.i64 should optimize (captures all 64 lanes)
@@ -27,9 +26,6 @@ define amdgpu_kernel void @wave64_ballot_i64_all_lanes(i32 %x, ptr addrspace(1)
2726
; CHECK: bar:
2827
; CHECK-NEXT: ret void
2928
;
30-
; O2-LABEL: @wave64_ballot_i64_all_lanes(
31-
; O2-NEXT: common.ret:
32-
; O2-NEXT: ret void
3329
;
3430
%cmp = icmp eq i32 %x, 0
3531
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)

0 commit comments

Comments
 (0)