
Commit eb85899

[InstCombine] Fold selects into masked loads (#160522)
A select can be folded into a masked load when the select's condition is the same mask used by the load and the load has no other uses; the select's false operand then becomes the load's passthrough.
1 parent e3aa00e commit eb85899
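
As a reference for the transform, here is a minimal IR sketch (illustrative only; value names such as %p and %passthrough are hypothetical, and the committed regression tests below exercise the same pattern). When the select's condition is the mask of a one-use masked load, the select's false operand simply becomes the load's passthrough:

; before
%load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
%sel = select <4 x i1> %mask, <4 x i32> %load, <4 x i32> %passthrough

; after
%sel = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)

For lanes where %mask is true both forms yield the loaded element, and for false lanes both yield %passthrough, so the select is redundant once its false value is folded in as the passthrough.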

2 files changed: +45 −2 lines

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 10 additions & 0 deletions
@@ -4611,5 +4611,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0));
   }
 
+  Value *MaskedLoadPtr;
+  const APInt *MaskedLoadAlignment;
+  if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr),
+                                           m_APInt(MaskedLoadAlignment),
+                                           m_Specific(CondVal), m_Value()))))
+    return replaceInstUsesWith(
+        SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr,
+                                     Align(MaskedLoadAlignment->getZExtValue()),
+                                     CondVal, FalseVal));
+
   return nullptr;
 }

llvm/test/Transforms/InstCombine/select-masked_load.ll

Lines changed: 35 additions & 2 deletions
@@ -26,8 +26,7 @@ define <4 x i32> @masked_load_and_zero_inactive_2(ptr %ptr, <4 x i1> %mask) {
 ; No transform when the load's passthrough cannot be reused or altered.
 define <4 x i32> @masked_load_and_zero_inactive_3(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthrough) {
 ; CHECK-LABEL: @masked_load_and_zero_inactive_3(
-; CHECK-NEXT:    [[LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHROUGH:%.*]])
-; CHECK-NEXT:    [[MASKED:%.*]] = select <4 x i1> [[MASK]], <4 x i32> [[LOAD]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[MASKED:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> zeroinitializer)
 ; CHECK-NEXT:    ret <4 x i32> [[MASKED]]
 ;
   %load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
@@ -116,6 +115,40 @@ entry:
   ret <8 x float> %1
 }
 
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable(
+; CHECK-NEXT:    [[SEL:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    ret <vscale x 4 x float> [[SEL]]
+;
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask2, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @neg_fold_sel_into_masked_load_mask_mismatch(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK2:%.*]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH]]
+; CHECK-NEXT:    ret <vscale x 4 x float> [[SEL]]
+;
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
+  %sel = select <vscale x 4 x i1> %mask2, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable_one_use_check(ptr %loc1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough, ptr %loc2) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable_one_use_check(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> zeroinitializer)
+; CHECK-NEXT:    [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH:%.*]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[LOAD]], ptr [[LOC2:%.*]], i32 1, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT:    ret <vscale x 4 x float> [[SEL]]
+;
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load, ptr %loc2, i32 1, <vscale x 4 x i1> %mask)
+  ret <vscale x 4 x float> %sel
+}
+
 declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>)
 declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)
 declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>)
