Skip to content

Commit c335ed6

Browse files
committed
Use PatternMatch
1 parent 11ac166 commit c335ed6

File tree

2 files changed

+27
-31
lines changed

2 files changed

+27
-31
lines changed

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 12 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2299,30 +2299,26 @@ bool GVNPass::processMaskedLoad(IntrinsicInst *I) {
22992299
if (!DepInst || !Dep.isLocal())
23002300
return false;
23012301

2302-
auto *MaskedStore = dyn_cast<IntrinsicInst>(DepInst);
2303-
if (!MaskedStore || MaskedStore->getIntrinsicID() != Intrinsic::masked_store)
2302+
Value *StoreVal;
2303+
if (!match(DepInst,
2304+
m_Intrinsic<Intrinsic::masked_store>(m_Value(StoreVal), m_Value(),
2305+
m_Value(), m_Specific(Mask))))
23042306
return false;
23052307

2306-
auto StoreMask = MaskedStore->getOperand(3);
2307-
if (StoreMask != Mask)
2308-
return false;
2309-
2310-
Value *OpToForward =
2311-
AvailableValue::get(MaskedStore->getOperand(0)).getSimpleValue();
2312-
if (auto *LoadToForward = dyn_cast<IntrinsicInst>(OpToForward);
2313-
LoadToForward &&
2314-
LoadToForward->getIntrinsicID() == Intrinsic::masked_load) {
2308+
Value *OpToForward = nullptr;
2309+
if (match(StoreVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(Mask),
2310+
m_Specific(Passthrough))))
23152311
// For MaskedLoad->MaskedStore->MaskedLoad, the mask must be the same for
23162312
// all three instructions. The Passthrough on the two loads must also be the
23172313
// same.
2318-
if (LoadToForward->getOperand(2) != Mask ||
2319-
LoadToForward->getOperand(3) != Passthrough)
2320-
return false;
2321-
} else {
2314+
OpToForward = AvailableValue::get(StoreVal).getSimpleValue();
2315+
else if (match(StoreVal, m_Intrinsic<Intrinsic::masked_load>()))
2316+
return false;
2317+
else {
23222318
// MaskedStore(Op, ptr, mask)->MaskedLoad(ptr, mask, passthrough) can be
23232319
// replaced with MaskedStore(Op, ptr, mask)->select(mask, Op, passthrough)
23242320
IRBuilder<> Builder(I);
2325-
OpToForward = Builder.CreateSelect(StoreMask, OpToForward, Passthrough);
2321+
OpToForward = Builder.CreateSelect(Mask, StoreVal, Passthrough);
23262322
}
23272323

23282324
I->replaceAllUsesWith(OpToForward);

llvm/test/Transforms/GVN/masked-load-store.ll

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,11 @@ define <4 x float> @forward_masked_load(ptr %0, ptr %1) {
4242
; CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> [[TMP4]], ptr [[TMP1:%.*]], i32 1, <4 x i1> splat (i1 true))
4343
; CHECK-NEXT: ret <4 x float> [[TMP4]]
4444
;
45-
%6 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 0, i32 4)
46-
%7 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %0, i32 1, <4 x i1> %6, <4 x float> zeroinitializer)
47-
call void @llvm.masked.store.v4f32.p0(<4 x float> %7, ptr %1, i32 1, <4 x i1> %6)
48-
%8 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %1, i32 1, <4 x i1> %6, <4 x float> zeroinitializer)
49-
ret <4 x float> %8
45+
%mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 0, i32 4)
46+
%load1 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %0, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
47+
call void @llvm.masked.store.v4f32.p0(<4 x float> %load1, ptr %1, i32 1, <4 x i1> %mask)
48+
%load2 = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %1, i32 1, <4 x i1> %mask, <4 x float> zeroinitializer)
49+
ret <4 x float> %load2
5050
}
5151

5252
define <4 x float> @forward_binop_splat_i1_mask(ptr %0, ptr %1) {
@@ -96,11 +96,11 @@ define <vscale x 4 x float> @forward_masked_load_scalable(ptr %0, ptr %1, <vscal
9696
; CHECK-NEXT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[TMP4]], ptr [[TMP1:%.*]], i32 1, <vscale x 4 x i1> [[TMP3]])
9797
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP4]]
9898
;
99-
%6 = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
100-
%7 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %0, i32 1, <vscale x 4 x i1> %6, <vscale x 4 x float> %passthrough)
101-
call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %7, ptr %1, i32 1, <vscale x 4 x i1> %6)
102-
%8 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %1, i32 1, <vscale x 4 x i1> %6, <vscale x 4 x float> %passthrough)
103-
ret <vscale x 4 x float> %8
99+
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
100+
%load1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
101+
call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load1, ptr %1, i32 1, <vscale x 4 x i1> %mask)
102+
%load2 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
103+
ret <vscale x 4 x float> %load2
104104
}
105105

106106
define <vscale x 4 x float> @bail_on_different_passthrough(ptr %0, ptr %1, <vscale x 4 x float> %passthrough) {
@@ -111,11 +111,11 @@ define <vscale x 4 x float> @bail_on_different_passthrough(ptr %0, ptr %1, <vsca
111111
; CHECK-NEXT: [[TMP5:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[TMP1]], i32 1, <vscale x 4 x i1> [[TMP3]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
112112
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP5]]
113113
;
114-
%6 = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
115-
%7 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %0, i32 1, <vscale x 4 x i1> %6, <vscale x 4 x float> zeroinitializer)
116-
call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %7, ptr %1, i32 1, <vscale x 4 x i1> %6)
117-
%8 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %1, i32 1, <vscale x 4 x i1> %6, <vscale x 4 x float> %passthrough)
118-
ret <vscale x 4 x float> %8
114+
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
115+
%load1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
116+
call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load1, ptr %1, i32 1, <vscale x 4 x i1> %mask)
117+
%load2 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
118+
ret <vscale x 4 x float> %load2
119119
}
120120

121121
define <vscale x 4 x float> @forward_binop_with_sel_scalable(ptr %0, ptr %1, <vscale x 4 x float> %passthrough) {

0 commit comments

Comments
 (0)