Skip to content

Commit 8f7fe4a

Browse files
committed
Respond to david-arm's review comments
1 parent 0d7107f commit 8f7fe4a

File tree

4 files changed

+63
-32
lines changed

4 files changed

+63
-32
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2781,6 +2781,14 @@ m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
27812781
return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3);
27822782
}
27832783

2784+
/// Matches MaskedStore Intrinsic.
2785+
template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
2786+
inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty
2787+
m_MaskedStore(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
2788+
const Opnd3 &Op3) {
2789+
return m_Intrinsic<Intrinsic::masked_store>(Op0, Op1, Op2, Op3);
2790+
}
2791+
27842792
/// Matches MaskedGather Intrinsic.
27852793
template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
27862794
inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
#include "llvm/IR/DebugLoc.h"
5151
#include "llvm/IR/Dominators.h"
5252
#include "llvm/IR/Function.h"
53-
#include "llvm/IR/IRBuilder.h"
5453
#include "llvm/IR/InstrTypes.h"
5554
#include "llvm/IR/Instruction.h"
5655
#include "llvm/IR/Instructions.h"
@@ -2291,39 +2290,29 @@ bool GVNPass::processLoad(LoadInst *L) {
22912290
// Attempt to process masked loads which have loaded from
22922291
// masked stores with the same mask
22932292
bool GVNPass::processMaskedLoad(IntrinsicInst *I) {
2294-
Value *Mask = I->getOperand(2);
2295-
Value *Passthrough = I->getOperand(3);
2296-
2293+
if (!MD)
2294+
return false;
22972295
MemDepResult Dep = MD->getDependency(I);
22982296
Instruction *DepInst = Dep.getInst();
2299-
if (!DepInst || !Dep.isLocal())
2297+
if (!DepInst || !Dep.isLocal() || !Dep.isDef())
23002298
return false;
23012299

2300+
Value *Mask = I->getOperand(2);
2301+
Value *Passthrough = I->getOperand(3);
23022302
Value *StoreVal;
2303-
if (!match(DepInst,
2304-
m_Intrinsic<Intrinsic::masked_store>(m_Value(StoreVal), m_Value(),
2305-
m_Value(), m_Specific(Mask))))
2303+
if (!match(DepInst, m_MaskedStore(m_Value(StoreVal), m_Value(), m_Value(),
2304+
m_Specific(Mask))))
23062305
return false;
23072306

2308-
Value *OpToForward = nullptr;
2309-
if (match(StoreVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(Mask),
2310-
m_Specific(Passthrough))))
2311-
// For MaskedLoad->MaskedStore->MaskedLoad, the mask must be the same for
2312-
// all three instructions. The Passthrough on the two loads must also be the
2313-
// same.
2314-
OpToForward = AvailableValue::get(StoreVal).getSimpleValue();
2315-
else if (match(StoreVal, m_Intrinsic<Intrinsic::masked_load>()))
2316-
return false;
2317-
else {
2318-
// MaskedStore(Op, ptr, mask)->MaskedLoad(ptr, mask, passthrough) can be
2319-
// replaced with MaskedStore(Op, ptr, mask)->select(mask, Op, passthrough)
2320-
IRBuilder<> Builder(I);
2321-
OpToForward = Builder.CreateSelect(Mask, StoreVal, Passthrough);
2322-
}
2307+
// Remove the load but generate a select for the passthrough
2308+
Value *OpToForward = llvm::SelectInst::Create(Mask, StoreVal, Passthrough, "",
2309+
I->getIterator());
23232310

2324-
I->replaceAllUsesWith(OpToForward);
23252311
ICF->removeUsersOf(I);
2312+
I->replaceAllUsesWith(OpToForward);
23262313
salvageAndRemoveInstruction(I);
2314+
if (OpToForward->getType()->isPtrOrPtrVectorTy())
2315+
MD->invalidateCachedPointerInfo(OpToForward);
23272316
++NumGVNLoad;
23282317
return true;
23292318
}
@@ -2775,10 +2764,9 @@ bool GVNPass::processInstruction(Instruction *I) {
27752764
return false;
27762765
}
27772766

2778-
if (auto *II = dyn_cast<IntrinsicInst>(I))
2779-
if (II && II->getIntrinsicID() == Intrinsic::masked_load)
2780-
if (processMaskedLoad(II))
2781-
return true;
2767+
if (match(I, m_Intrinsic<Intrinsic::masked_load>()) &&
2768+
processMaskedLoad(cast<IntrinsicInst>(I)))
2769+
return true;
27822770

27832771
// For conditional branches, we can perform simple conditional propagation on
27842772
// the condition value itself.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes=gvn -S -enable-gvn-memdep=true < %s | FileCheck %s
3+
; RUN: opt -passes=gvn -S -enable-gvn-memdep=false < %s | FileCheck %s --check-prefix=MEMDEPFALSE
4+
5+
define <4 x float> @forward_binop_with_sel(ptr %0, ptr %1, i32 %a, i32 %b, <4 x float> %passthrough) {
6+
; CHECK-LABEL: @forward_binop_with_sel(
7+
; CHECK-NEXT: [[MASK:%.*]] = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[A:%.*]], i32 [[B:%.*]])
8+
; CHECK-NEXT: [[LOAD_0_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP0:%.*]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
9+
; CHECK-NEXT: [[GEP_0_16:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16
10+
; CHECK-NEXT: [[LOAD_0_16:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[GEP_0_16]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
11+
; CHECK-NEXT: [[FMUL:%.*]] = fmul <4 x float> [[LOAD_0_0]], [[LOAD_0_16]]
12+
; CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> [[FMUL]], ptr [[TMP1:%.*]], i32 1, <4 x i1> [[MASK]])
13+
; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[MASK]], <4 x float> [[FMUL]], <4 x float> [[PASSTHROUGH:%.*]]
14+
; CHECK-NEXT: ret <4 x float> [[TMP3]]
15+
;
16+
; MEMDEPFALSE-LABEL: @forward_binop_with_sel(
17+
; MEMDEPFALSE-NEXT: [[MASK:%.*]] = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[A:%.*]], i32 [[B:%.*]])
18+
; MEMDEPFALSE-NEXT: [[LOAD_0_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP0:%.*]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
19+
; MEMDEPFALSE-NEXT: [[GEP_0_16:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16
20+
; MEMDEPFALSE-NEXT: [[LOAD_0_16:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[GEP_0_16]], i32 1, <4 x i1> [[MASK]], <4 x float> zeroinitializer)
21+
; MEMDEPFALSE-NEXT: [[FMUL:%.*]] = fmul <4 x float> [[LOAD_0_0]], [[LOAD_0_16]]
22+
; MEMDEPFALSE-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> [[FMUL]], ptr [[TMP1:%.*]], i32 1, <4 x i1> [[MASK]])
23+
; MEMDEPFALSE-NEXT: [[LOAD_1_0:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr [[TMP1]], i32 1, <4 x i1> [[MASK]], <4 x float> [[PASSTHROUGH:%.*]])
24+
; MEMDEPFALSE-NEXT: ret <4 x float> [[LOAD_1_0]]
25+
;
26+
%mask = tail call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %a, i32 %b)
27+
%load.0.0 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %0, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
28+
%gep.0.16 = getelementptr i8, ptr %0, i32 16
29+
%load.0.16 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %gep.0.16, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
30+
%fmul = fmul <4 x float> %load.0.0, %load.0.16
31+
call void @llvm.masked.store.v4f32.p0(<4 x float> %fmul, ptr %1, i32 1, <4 x i1> %mask)
32+
%load.1.0 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %1, i32 1, <4 x i1> %mask, <4 x float> %passthrough)
33+
ret <4 x float> %load.1.0
34+
}

llvm/test/Transforms/GVN/masked-load-store.ll

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ define <vscale x 4 x float> @forward_masked_load_scalable(ptr %0, ptr %1, <vscal
9494
; CHECK-NEXT: [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
9595
; CHECK-NEXT: [[TMP4:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[TMP0:%.*]], i32 1, <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
9696
; CHECK-NEXT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[TMP4]], ptr [[TMP1:%.*]], i32 1, <vscale x 4 x i1> [[TMP3]])
97-
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP4]]
97+
; CHECK-NEXT: [[TMP5:%.*]] = select <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> [[TMP4]], <vscale x 4 x float> [[PASSTHROUGH]]
98+
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP5]]
9899
;
99100
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
100101
%load1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
@@ -103,12 +104,12 @@ define <vscale x 4 x float> @forward_masked_load_scalable(ptr %0, ptr %1, <vscal
103104
ret <vscale x 4 x float> %load2
104105
}
105106

106-
define <vscale x 4 x float> @bail_on_different_passthrough(ptr %0, ptr %1, <vscale x 4 x float> %passthrough) {
107-
; CHECK-LABEL: @bail_on_different_passthrough(
107+
define <vscale x 4 x float> @generate_sel_with_passthrough(ptr %0, ptr %1, <vscale x 4 x float> %passthrough) {
108+
; CHECK-LABEL: @generate_sel_with_passthrough(
108109
; CHECK-NEXT: [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
109110
; CHECK-NEXT: [[TMP4:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[TMP0:%.*]], i32 1, <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> zeroinitializer)
110111
; CHECK-NEXT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[TMP4]], ptr [[TMP1:%.*]], i32 1, <vscale x 4 x i1> [[TMP3]])
111-
; CHECK-NEXT: [[TMP5:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[TMP1]], i32 1, <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
112+
; CHECK-NEXT: [[TMP5:%.*]] = select <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> [[TMP4]], <vscale x 4 x float> [[PASSTHROUGH:%.*]]
112113
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP5]]
113114
;
114115
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)

0 commit comments

Comments
 (0)