Skip to content

Commit f5f8655

Browse files
Moved the assume-based ballot folding logic to AMDGPUInstCombineIntrinsic.cpp
1 parent 4e075a6 commit f5f8655

File tree

3 files changed

+97
-106
lines changed

3 files changed

+97
-106
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "AMDGPUTargetTransformInfo.h"
1919
#include "GCNSubtarget.h"
2020
#include "llvm/ADT/FloatingPointMode.h"
21+
#include "llvm/Analysis/AssumptionCache.h"
2122
#include "llvm/IR/Dominators.h"
2223
#include "llvm/IR/IntrinsicsAMDGPU.h"
2324
#include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -1341,6 +1342,58 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13411342
Call->takeName(&II);
13421343
return IC.replaceInstUsesWith(II, Call);
13431344
}
1345+
1346+
// Fold ballot intrinsic based on llvm.assume hint about the result.
1347+
//
1348+
// assume(ballot(x) == ballot(i1 true)) -> x = true
1349+
// assume(ballot(x) == -1) -> x = true
1350+
// assume(ballot(x) == 0) -> x = false
1351+
if (Arg->getType()->isIntegerTy(1)) {
1352+
for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
1353+
if (!AssumeVH)
1354+
continue;
1355+
1356+
auto *Assume = cast<AssumeInst>(AssumeVH);
1357+
Value *Cond = Assume->getArgOperand(0);
1358+
1359+
// Check if assume condition is an equality comparison.
1360+
auto *ICI = dyn_cast<ICmpInst>(Cond);
1361+
if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
1362+
continue;
1363+
1364+
// Extract the ballot and the value being compared against it.
1365+
Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
1366+
Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
1367+
if (!CompareValue)
1368+
continue;
1369+
1370+
// Determine the constant value of the ballot's condition argument.
1371+
std::optional<bool> PropagatedBool;
1372+
if (match(CompareValue, m_AllOnes()) ||
1373+
match(CompareValue,
1374+
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
1375+
// ballot(x) == -1 or ballot(x) == ballot(true) means x is true.
1376+
PropagatedBool = true;
1377+
} else if (match(CompareValue, m_Zero())) {
1378+
// ballot(x) == 0 means x is false.
1379+
PropagatedBool = false;
1380+
}
1381+
1382+
if (!PropagatedBool)
1383+
continue;
1384+
1385+
Constant *PropagatedValue =
1386+
ConstantInt::getBool(Arg->getContext(), *PropagatedBool);
1387+
1388+
// Replace dominated uses of the ballot's condition argument with the
1389+
// propagated value.
1390+
Arg->replaceUsesWithIf(PropagatedValue, [&](Use &U) {
1391+
Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
1392+
return UserInst && IC.getDominatorTree().dominates(Assume, U);
1393+
});
1394+
}
1395+
}
1396+
13441397
break;
13451398
}
13461399
case Intrinsic::amdgcn_wavefrontsize: {

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
#include "llvm/IR/Instruction.h"
5555
#include "llvm/IR/Instructions.h"
5656
#include "llvm/IR/IntrinsicInst.h"
57-
#include "llvm/IR/IntrinsicsAMDGPU.h"
5857
#include "llvm/IR/LLVMContext.h"
5958
#include "llvm/IR/Metadata.h"
6059
#include "llvm/IR/Module.h"
@@ -2540,63 +2539,6 @@ bool GVNPass::propagateEquality(
25402539
}
25412540
}
25422541

2543-
// Helper function to check if a value represents the current exec mask.
2544-
auto IsExecMask = [](Value *V) -> bool {
2545-
// Pattern 1: ballot(true)
2546-
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(V)) {
2547-
if (II->getIntrinsicID() == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2548-
// Check if argument is constant true
2549-
if (match(II->getArgOperand(0), m_One())) {
2550-
return true;
2551-
}
2552-
}
2553-
}
2554-
2555-
return false;
2556-
};
2557-
2558-
// Check if either of the operands is a ballot intrinsic.
2559-
IntrinsicInst *BallotCall = nullptr;
2560-
Value *CompareValue = nullptr;
2561-
2562-
// Check both LHS and RHS for ballot intrinsic and its value since GVN may
2563-
// swap the operands.
2564-
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHS)) {
2565-
if (II->getIntrinsicID() == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2566-
BallotCall = II;
2567-
CompareValue = RHS;
2568-
}
2569-
}
2570-
if (!BallotCall && isa<IntrinsicInst>(RHS)) {
2571-
IntrinsicInst *II = cast<IntrinsicInst>(RHS);
2572-
if (II->getIntrinsicID() == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2573-
BallotCall = II;
2574-
CompareValue = LHS;
2575-
}
2576-
}
2577-
2578-
// If a ballot intrinsic is found, calculate the truth value of the ballot
2579-
// argument based on the RHS.
2580-
if (BallotCall) {
2581-
Value *BallotArg = BallotCall->getArgOperand(0);
2582-
if (BallotArg->getType()->isIntegerTy(1)) {
2583-
// Case 1: ballot(cond) == -1: cond true in all lanes -> cond = true.
2584-
// Case 2: ballot(cond) == exec_mask: cond true in all active lanes ->
2585-
// cond = true.
2586-
if (match(CompareValue, m_AllOnes()) || IsExecMask(CompareValue)) {
2587-
Worklist.push_back(std::make_pair(
2588-
BallotArg, ConstantInt::getTrue(BallotArg->getType())));
2589-
continue;
2590-
}
2591-
// Case 3: ballot(cond) == 0: cond false in all lanes -> cond = false.
2592-
if (match(CompareValue, m_Zero())) {
2593-
Worklist.push_back(std::make_pair(
2594-
BallotArg, ConstantInt::getFalse(BallotArg->getType())));
2595-
continue;
2596-
}
2597-
}
2598-
}
2599-
26002542
// Now try to deduce additional equalities from this one. For example, if
26012543
// the known equality was "(A != B)" == "false" then it follows that A and B
26022544
// are equal in the scope. Only boolean equalities with an explicit true or

llvm/test/Transforms/GVN/assume-ballot.ll renamed to llvm/test/Transforms/InstCombine/AMDGPU/llvm.amdgcn.ballot-assume.ll

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2-
; RUN: opt < %s -passes=gvn -S | FileCheck %s
2+
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -passes=instcombine -S | FileCheck %s
33
;
4-
; Tests for assume-based ballot optimizations
5-
; This optimization recognizes patterns like:
6-
; assume(ballot(cmp) == -1) -> cmp is true on all lanes
7-
; assume(ballot(cmp) == 0) -> cmp is false on all lanes
4+
; Tests for assume-based ballot optimizations for patterns like:
5+
; assume(ballot(cmp) == -1) -> replace uses of cmp with true
6+
; assume(ballot(cmp) == 0) -> replace uses of cmp with false
7+
; assume(ballot(cmp) == ballot(1)) -> replace uses of cmp with true
88

99
declare void @llvm.assume(i1)
1010
declare i64 @llvm.amdgcn.ballot.i64(i1)
@@ -26,7 +26,6 @@ define amdgpu_kernel void @assume_ballot_all_lanes_i64(i32 %x, ptr addrspace(1)
2626
; CHECK-NEXT: store i32 1, ptr addrspace(1) [[OUT:%.*]], align 4
2727
; CHECK-NEXT: ret void
2828
; CHECK: bar:
29-
; CHECK-NEXT: store i32 0, ptr addrspace(1) [[OUT]], align 4
3029
; CHECK-NEXT: ret void
3130
;
3231
%cmp = icmp eq i32 %x, 0
@@ -70,7 +69,7 @@ define amdgpu_kernel void @assume_ballot_exec_mask_ballot_true(i32 %x, ptr addrs
7069
; CHECK-NEXT: [[EXEC:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
7170
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], [[EXEC]]
7271
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
73-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
72+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
7473
; CHECK: foo:
7574
; CHECK-NEXT: ret void
7675
; CHECK: bar:
@@ -147,7 +146,7 @@ define amdgpu_kernel void @assume_ballot_exec_mask_wave32(i32 %x, ptr addrspace(
147146
; CHECK-NEXT: [[EXEC:%.*]] = call i32 @llvm.amdgcn.ballot.i32(i1 true)
148147
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i32 [[BALLOT]], [[EXEC]]
149148
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
150-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
149+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
151150
; CHECK: foo:
152151
; CHECK-NEXT: ret void
153152
; CHECK: bar:
@@ -175,7 +174,7 @@ define amdgpu_kernel void @assume_ballot_dominance(i32 %x, ptr addrspace(1) %out
175174
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
176175
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
177176
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
178-
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i32, ptr addrspace(1) [[OUT]], i64 1
177+
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i8, ptr addrspace(1) [[OUT]], i64 4
179178
; CHECK-NEXT: store i32 1, ptr addrspace(1) [[OUT2]], align 4
180179
; CHECK-NEXT: ret void
181180
;
@@ -196,7 +195,7 @@ define amdgpu_kernel void @assume_ballot_swapped(i32 %x, ptr addrspace(1) %out)
196195
; CHECK-LABEL: @assume_ballot_swapped(
197196
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
198197
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
199-
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 -1, [[BALLOT]]
198+
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
200199
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
201200
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
202201
; CHECK: foo:
@@ -224,7 +223,7 @@ define amdgpu_kernel void @assume_ballot_exec_mask_swapped(i32 %x, ptr addrspace
224223
; CHECK-NEXT: [[EXEC:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
225224
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[EXEC]], [[BALLOT]]
226225
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
227-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
226+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
228227
; CHECK: foo:
229228
; CHECK-NEXT: ret void
230229
; CHECK: bar:
@@ -251,7 +250,7 @@ define amdgpu_kernel void @assume_ballot_multiple_uses(i32 %x, ptr addrspace(1)
251250
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
252251
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
253252
; CHECK-NEXT: store i32 1, ptr addrspace(1) [[OUT:%.*]], align 4
254-
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i32, ptr addrspace(1) [[OUT]], i64 1
253+
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i8, ptr addrspace(1) [[OUT]], i64 4
255254
; CHECK-NEXT: store i32 10, ptr addrspace(1) [[OUT2]], align 4
256255
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
257256
; CHECK: foo:
@@ -284,12 +283,10 @@ define amdgpu_kernel void @assume_ballot_exec_mask_multiple_uses(i32 %x, ptr add
284283
; CHECK-NEXT: [[EXEC:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
285284
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], [[EXEC]]
286285
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
287-
; CHECK-NEXT: [[USE1:%.*]] = zext i1 [[CMP]] to i32
288-
; CHECK-NEXT: store i32 [[USE1]], ptr addrspace(1) [[OUT:%.*]], align 4
289-
; CHECK-NEXT: [[USE2:%.*]] = select i1 [[CMP]], i32 10, i32 20
290-
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i32, ptr addrspace(1) [[OUT]], i64 1
291-
; CHECK-NEXT: store i32 [[USE2]], ptr addrspace(1) [[OUT2]], align 4
292-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
286+
; CHECK-NEXT: store i32 1, ptr addrspace(1) [[OUT:%.*]], align 4
287+
; CHECK-NEXT: [[OUT2:%.*]] = getelementptr i8, ptr addrspace(1) [[OUT]], i64 4
288+
; CHECK-NEXT: store i32 10, ptr addrspace(1) [[OUT2]], align 4
289+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
293290
; CHECK: foo:
294291
; CHECK-NEXT: ret void
295292
; CHECK: bar:
@@ -313,27 +310,24 @@ bar:
313310
ret void
314311
}
315312

316-
; ============================================================================
317-
; NEGATIVE CASES
318-
; ============================================================================
319-
320-
; Test 1: assume(ballot != -1) -> cmp should not be transformed (cmp is false in at least one lane)
321-
define amdgpu_kernel void @assume_ballot_ne_negative(i32 %x, ptr addrspace(1) %out) {
322-
; CHECK-LABEL: @assume_ballot_ne_negative(
313+
; Test 12: ballot(cmp) == ballot(false) -> cmp replaced with false
314+
define amdgpu_kernel void @assume_ballot_false(i32 %x, ptr addrspace(1) %out) {
315+
; CHECK-LABEL: @assume_ballot_false(
323316
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
324317
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
325-
; CHECK-NEXT: [[NOT_ALL:%.*]] = icmp ne i64 [[BALLOT]], -1
326-
; CHECK-NEXT: call void @llvm.assume(i1 [[NOT_ALL]])
327-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
318+
; CHECK-NEXT: [[MATCHES:%.*]] = icmp eq i64 [[BALLOT]], 0
319+
; CHECK-NEXT: call void @llvm.assume(i1 [[MATCHES]])
320+
; CHECK-NEXT: br i1 false, label [[FOO:%.*]], label [[BAR:%.*]]
328321
; CHECK: foo:
329322
; CHECK-NEXT: ret void
330323
; CHECK: bar:
331324
; CHECK-NEXT: ret void
332325
;
333326
%cmp = icmp eq i32 %x, 0
334327
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
335-
%not_all = icmp ne i64 %ballot, -1
336-
call void @llvm.assume(i1 %not_all)
328+
%not_exec = call i64 @llvm.amdgcn.ballot.i64(i1 false)
329+
%matches = icmp eq i64 %ballot, %not_exec
330+
call void @llvm.assume(i1 %matches)
337331
br i1 %cmp, label %foo, label %bar
338332

339333
foo:
@@ -342,13 +336,17 @@ bar:
342336
ret void
343337
}
344338

345-
; Test 2: assume(ballot != 0) -> cmp should not be transformed (cmp is true in at least one lane)
346-
define amdgpu_kernel void @assume_ballot_ne_zero_negative(i32 %x, ptr addrspace(1) %out) {
347-
; CHECK-LABEL: @assume_ballot_ne_zero_negative(
339+
; ============================================================================
340+
; NEGATIVE CASES
341+
; ============================================================================
342+
343+
; Test 1: assume(ballot != -1) -> no transformation (requires icmp eq)
344+
define amdgpu_kernel void @assume_ballot_ne_negative(i32 %x, ptr addrspace(1) %out) {
345+
; CHECK-LABEL: @assume_ballot_ne_negative(
348346
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
349347
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
350-
; CHECK-NEXT: [[SOME:%.*]] = icmp ne i64 [[BALLOT]], 0
351-
; CHECK-NEXT: call void @llvm.assume(i1 [[SOME]])
348+
; CHECK-NEXT: [[NOT_ALL:%.*]] = icmp ne i64 [[BALLOT]], -1
349+
; CHECK-NEXT: call void @llvm.assume(i1 [[NOT_ALL]])
352350
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
353351
; CHECK: foo:
354352
; CHECK-NEXT: ret void
@@ -357,8 +355,8 @@ define amdgpu_kernel void @assume_ballot_ne_zero_negative(i32 %x, ptr addrspace(
357355
;
358356
%cmp = icmp eq i32 %x, 0
359357
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
360-
%some = icmp ne i64 %ballot, 0
361-
call void @llvm.assume(i1 %some)
358+
%not_all = icmp ne i64 %ballot, -1
359+
call void @llvm.assume(i1 %not_all)
362360
br i1 %cmp, label %foo, label %bar
363361

364362
foo:
@@ -367,14 +365,13 @@ bar:
367365
ret void
368366
}
369367

370-
; Test 3: ballot(cmp) == ballot(false) -> cmp should not be transformed (RHS is not EXEC MASK)
371-
define amdgpu_kernel void @assume_ballot_not_exec_mask(i32 %x, ptr addrspace(1) %out) {
372-
; CHECK-LABEL: @assume_ballot_not_exec_mask(
368+
; Test 2: assume(ballot != 0) -> no transformation (requires icmp eq)
369+
define amdgpu_kernel void @assume_ballot_ne_zero_negative(i32 %x, ptr addrspace(1) %out) {
370+
; CHECK-LABEL: @assume_ballot_ne_zero_negative(
373371
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
374372
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
375-
; CHECK-NEXT: [[NOT_EXEC:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 false)
376-
; CHECK-NEXT: [[MATCHES:%.*]] = icmp eq i64 [[BALLOT]], [[NOT_EXEC]]
377-
; CHECK-NEXT: call void @llvm.assume(i1 [[MATCHES]])
373+
; CHECK-NEXT: [[SOME:%.*]] = icmp ne i64 [[BALLOT]], 0
374+
; CHECK-NEXT: call void @llvm.assume(i1 [[SOME]])
378375
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
379376
; CHECK: foo:
380377
; CHECK-NEXT: ret void
@@ -383,9 +380,8 @@ define amdgpu_kernel void @assume_ballot_not_exec_mask(i32 %x, ptr addrspace(1)
383380
;
384381
%cmp = icmp eq i32 %x, 0
385382
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
386-
%not_exec = call i64 @llvm.amdgcn.ballot.i64(i1 false)
387-
%matches = icmp eq i64 %ballot, %not_exec
388-
call void @llvm.assume(i1 %matches)
383+
%some = icmp ne i64 %ballot, 0
384+
call void @llvm.assume(i1 %some)
389385
br i1 %cmp, label %foo, label %bar
390386

391387
foo:
@@ -394,7 +390,7 @@ bar:
394390
ret void
395391
}
396392

397-
; Test 4: Constant as mask value (other than -1 or 0) -> cmp should not be transformed
393+
; Test 3: Constant mask (other than -1/0) -> no transformation
398394
define amdgpu_kernel void @assume_ballot_constant_mask(i32 %x, ptr addrspace(1) %out) {
399395
; CHECK-LABEL: @assume_ballot_constant_mask(
400396
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
@@ -419,7 +415,7 @@ bar:
419415
ret void
420416
}
421417

422-
; Test 5: Arbitrary mask -> cmp should not be transformed
418+
; Test 4: Runtime mask value -> no transformation
423419
define amdgpu_kernel void @assume_ballot_arbitrary_mask(i32 %x, i64 %mask, ptr addrspace(1) %out) {
424420
; CHECK-LABEL: @assume_ballot_arbitrary_mask(
425421
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0

0 commit comments

Comments
 (0)