Skip to content

Commit 30c578a

Browse files
authored
[GVN] Teach GVN simple masked load/store forwarding (#157689)
This patch teaches GVN how to eliminate redundant masked loads and forward previous loads or instructions with a select. This is possible when the same mask is used for masked stores/loads that write to the same memory location
1 parent 72679c8 commit 30c578a

File tree

5 files changed

+253
-2
lines changed

5 files changed

+253
-2
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2773,6 +2773,14 @@ m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
27732773
return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3);
27742774
}
27752775

2776+
/// Matches MaskedStore Intrinsic.
2777+
template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
2778+
inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty
2779+
m_MaskedStore(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
2780+
const Opnd3 &Op3) {
2781+
return m_Intrinsic<Intrinsic::masked_store>(Op0, Op1, Op2, Op3);
2782+
}
2783+
27762784
/// Matches MaskedGather Intrinsic.
27772785
template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
27782786
inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty

llvm/include/llvm/Transforms/Scalar/GVN.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class OptimizationRemarkEmitter;
5656
class PHINode;
5757
class TargetLibraryInfo;
5858
class Value;
59+
class IntrinsicInst;
5960
/// A private "module" namespace for types and utilities used by GVN. These
6061
/// are implementation details and should not be used by clients.
6162
namespace LLVM_LIBRARY_VISIBILITY_NAMESPACE gvn {
@@ -349,6 +350,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
349350

350351
// Helper functions of redundant load elimination.
351352
bool processLoad(LoadInst *L);
353+
bool processMaskedLoad(IntrinsicInst *I);
352354
bool processNonLocalLoad(LoadInst *L);
353355
bool processAssumeIntrinsic(AssumeInst *II);
354356

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2287,6 +2287,35 @@ bool GVNPass::processLoad(LoadInst *L) {
22872287
return true;
22882288
}
22892289

2290+
// Attempt to process masked loads which have loaded from
2291+
// masked stores with the same mask
2292+
bool GVNPass::processMaskedLoad(IntrinsicInst *I) {
2293+
if (!MD)
2294+
return false;
2295+
MemDepResult Dep = MD->getDependency(I);
2296+
Instruction *DepInst = Dep.getInst();
2297+
if (!DepInst || !Dep.isLocal() || !Dep.isDef())
2298+
return false;
2299+
2300+
Value *Mask = I->getOperand(2);
2301+
Value *Passthrough = I->getOperand(3);
2302+
Value *StoreVal;
2303+
if (!match(DepInst, m_MaskedStore(m_Value(StoreVal), m_Value(), m_Value(),
2304+
m_Specific(Mask))) ||
2305+
StoreVal->getType() != I->getType())
2306+
return false;
2307+
2308+
// Remove the load but generate a select for the passthrough
2309+
Value *OpToForward = llvm::SelectInst::Create(Mask, StoreVal, Passthrough, "",
2310+
I->getIterator());
2311+
2312+
ICF->removeUsersOf(I);
2313+
I->replaceAllUsesWith(OpToForward);
2314+
salvageAndRemoveInstruction(I);
2315+
++NumGVNLoad;
2316+
return true;
2317+
}
2318+
22902319
/// Return a pair the first field showing the value number of \p Exp and the
22912320
/// second field showing whether it is a value number newly created.
22922321
std::pair<uint32_t, bool>
@@ -2734,6 +2763,10 @@ bool GVNPass::processInstruction(Instruction *I) {
27342763
return false;
27352764
}
27362765

2766+
if (match(I, m_Intrinsic<Intrinsic::masked_load>()) &&
2767+
processMaskedLoad(cast<IntrinsicInst>(I)))
2768+
return true;
2769+
27372770
// For conditional branches, we can perform simple conditional propagation on
27382771
// the condition value itself.
27392772
if (BranchInst *BI = dyn_cast<BranchInst>(I)) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes=gvn -S -enable-gvn-memdep=true < %s | FileCheck %s
3+
; RUN: opt -passes=gvn -S -enable-gvn-memdep=false < %s | FileCheck %s --check-prefix=MEMDEPFALSE
4+
5+
define <4 x float> @forward_binop_with_sel(ptr %0, ptr %1, i32 %a, i32 %b, <4 x float> %passthrough) {
6+
; CHECK-LABEL: @forward_binop_with_sel(
7+
; CHECK-NEXT: [[MASK:%.*]] = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[A:%.*]], i32 [[B:%.*]])
8+
; CHECK-NEXT: [[LOAD_0_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP0:%.*]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
9+
; CHECK-NEXT: [[GEP_0_16:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16
10+
; CHECK-NEXT: [[LOAD_0_16:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[GEP_0_16]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
11+
; CHECK-NEXT: [[FMUL:%.*]] = fmul <4 x float> [[LOAD_0_0]], [[LOAD_0_16]]
12+
; CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> [[FMUL]], ptr [[TMP1:%.*]], i32 1, <4 x i1> [[MASK]])
13+
; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[MASK]], <4 x float> [[FMUL]], <4 x float> [[PASSTHROUGH:%.*]]
14+
; CHECK-NEXT: ret <4 x float> [[TMP3]]
15+
;
16+
; MEMDEPFALSE-LABEL: @forward_binop_with_sel(
17+
; MEMDEPFALSE-NEXT: [[MASK:%.*]] = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[A:%.*]], i32 [[B:%.*]])
18+
; MEMDEPFALSE-NEXT: [[LOAD_0_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP0:%.*]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
19+
; MEMDEPFALSE-NEXT: [[GEP_0_16:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16
20+
; MEMDEPFALSE-NEXT: [[LOAD_0_16:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[GEP_0_16]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
21+
; MEMDEPFALSE-NEXT: [[FMUL:%.*]] = fmul <4 x float> [[LOAD_0_0]], [[LOAD_0_16]]
22+
; MEMDEPFALSE-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> [[FMUL]], ptr [[TMP1:%.*]], i32 1, <4 x i1> [[MASK]])
23+
; MEMDEPFALSE-NEXT: [[LOAD_1_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP1]], i32 1, <4 x i1> [[MASK]], <4 x float> [[PASSTHROUGH:%.*]])
24+
; MEMDEPFALSE-NEXT: ret <4 x float> [[LOAD_1_0]]
25+
;
26+
%mask = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %a, i32 %b)
27+
%load.0.0 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %0, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
28+
%gep.0.16 = getelementptr i8, ptr %0, i32 16
29+
%load.0.16 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %gep.0.16, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
30+
%fmul = fmul <4 x float> %load.0.0, %load.0.16
31+
call void @llvm.masked.store.v4f32.p0(<4 x float> %fmul, ptr %1, i32 1, <4 x i1> %mask)
32+
%load.1.0 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %1, i32 1, <4 x i1> %mask, <4 x float> %passthrough)
33+
ret <4 x float> %load.1.0
34+
}

0 commit comments

Comments
 (0)