Skip to content

Commit 36b543a

Browse files
authored
[InstComb] Handle undef in simplifyMasked(Store|Scatter) (#161825)
1 parent 2d67cb1 commit 36b543a

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,18 +318,18 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
318318
// * Single constant active lane -> store
319319
// * Narrow width by halfs excluding zero/undef lanes
320320
Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
321+
Value *StorePtr = II.getArgOperand(1);
322+
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
321323
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
322324
if (!ConstMask)
323325
return nullptr;
324326

325327
// If the mask is all zeros, this instruction does nothing.
326-
if (ConstMask->isNullValue())
328+
if (maskIsAllZeroOrUndef(ConstMask))
327329
return eraseInstFromFunction(II);
328330

329331
// If the mask is all ones, this is a plain vector store of the 1st argument.
330-
if (ConstMask->isAllOnesValue()) {
331-
Value *StorePtr = II.getArgOperand(1);
332-
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
332+
if (maskIsAllOneOrUndef(ConstMask)) {
333333
StoreInst *S =
334334
new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
335335
S->copyMetadata(II);
@@ -389,7 +389,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
389389
return nullptr;
390390

391391
// If the mask is all zeros, a scatter does nothing.
392-
if (ConstMask->isNullValue())
392+
if (maskIsAllZeroOrUndef(ConstMask))
393393
return eraseInstFromFunction(II);
394394

395395
// Vector splat address -> scalar store

llvm/test/Transforms/InstCombine/masked_intrinsics.ll

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ define <2 x double> @load_zeromask(ptr %ptr, <2 x double> %passthru) {
1616
ret <2 x double> %res
1717
}
1818

19+
define <2 x double> @load_zero_withpoison_mask(ptr %ptr, <2 x double> %passthru) {
20+
; CHECK-LABEL: @load_zero_withpoison_mask(
21+
; CHECK-NEXT: ret <2 x double> [[PASSTHRU:%.*]]
22+
;
23+
%res = call <2 x double> @llvm.masked.load.v2f64.p0(ptr %ptr, i32 1, <2 x i1> <i1 0, i1 poison>, <2 x double> %passthru)
24+
ret <2 x double> %res
25+
}
26+
1927
define <2 x double> @load_onemask(ptr %ptr, <2 x double> %passthru) {
2028
; CHECK-LABEL: @load_onemask(
2129
; CHECK-NEXT: [[UNMASKEDLOAD:%.*]] = load <2 x double>, ptr [[PTR:%.*]], align 2
@@ -150,6 +158,14 @@ define void @store_zeromask(ptr %ptr, <2 x double> %val) {
150158
ret void
151159
}
152160

161+
define void @store_poisonmask(ptr %ptr, <2 x double> %val) {
162+
; CHECK-LABEL: @store_poisonmask(
163+
; CHECK-NEXT: ret void
164+
;
165+
call void @llvm.masked.store.v2f64.p0(<2 x double> %val, ptr %ptr, i32 4, <2 x i1> splat(i1 poison))
166+
ret void
167+
}
168+
153169
define void @store_onemask(ptr %ptr, <2 x double> %val) {
154170
; CHECK-LABEL: @store_onemask(
155171
; CHECK-NEXT: store <2 x double> [[VAL:%.*]], ptr [[PTR:%.*]], align 4
@@ -159,6 +175,15 @@ define void @store_onemask(ptr %ptr, <2 x double> %val) {
159175
ret void
160176
}
161177

178+
define void @store_one_withpoison_mask(ptr %ptr, <2 x double> %val) {
179+
; CHECK-LABEL: @store_one_withpoison_mask(
180+
; CHECK-NEXT: store <2 x double> [[VAL:%.*]], ptr [[PTR:%.*]], align 4
181+
; CHECK-NEXT: ret void
182+
;
183+
call void @llvm.masked.store.v2f64.p0(<2 x double> %val, ptr %ptr, i32 4, <2 x i1> <i1 1, i1 poison>)
184+
ret void
185+
}
186+
162187
define void @store_demandedelts(ptr %ptr, double %val) {
163188
; CHECK-LABEL: @store_demandedelts(
164189
; CHECK-NEXT: [[VALVEC1:%.*]] = insertelement <2 x double> poison, double [[VAL:%.*]], i64 0
@@ -189,6 +214,13 @@ define <2 x double> @gather_zeromask(<2 x ptr> %ptrs, <2 x double> %passthru) {
189214
ret <2 x double> %res
190215
}
191216

217+
define <2 x double> @gather_zero_withpoison_mask(<2 x ptr> %ptrs, <2 x double> %passthru) {
218+
; CHECK-LABEL: @gather_zero_withpoison_mask(
219+
; CHECK-NEXT: ret <2 x double> [[PASSTHRU:%.*]]
220+
;
221+
%res = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> %ptrs, i32 4, <2 x i1> <i1 0, i1 poison>, <2 x double> %passthru)
222+
ret <2 x double> %res
223+
}
192224

193225
define <2 x double> @gather_onemask(<2 x ptr> %ptrs, <2 x double> %passthru) {
194226
; CHECK-LABEL: @gather_onemask(
@@ -199,6 +231,15 @@ define <2 x double> @gather_onemask(<2 x ptr> %ptrs, <2 x double> %passthru) {
199231
ret <2 x double> %res
200232
}
201233

234+
define <2 x double> @gather_one_withpoisonmask(<2 x ptr> %ptrs, <2 x double> %passthru) {
235+
; CHECK-LABEL: @gather_one_withpoisonmask(
236+
; CHECK-NEXT: [[RES:%.*]] = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> [[PTRS:%.*]], i32 4, <2 x i1> <i1 true, i1 poison>, <2 x double> [[PASSTHRU:%.*]])
237+
; CHECK-NEXT: ret <2 x double> [[RES]]
238+
;
239+
%res = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> %ptrs, i32 4, <2 x i1> <i1 true, i1 poison>, <2 x double> %passthru)
240+
ret <2 x double> %res
241+
}
242+
202243
define <4 x double> @gather_lane2(ptr %base, double %pt) {
203244
; CHECK-LABEL: @gather_lane2(
204245
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr double, ptr [[BASE:%.*]], <4 x i64> <i64 poison, i64 poison, i64 2, i64 poison>
@@ -257,6 +298,23 @@ define void @scatter_zeromask(<2 x ptr> %ptrs, <2 x double> %val) {
257298
ret void
258299
}
259300

301+
define void @scatter_zero_withpoison_mask(<2 x ptr> %ptrs, <2 x double> %val) {
302+
; CHECK-LABEL: @scatter_zero_withpoison_mask(
303+
; CHECK-NEXT: ret void
304+
;
305+
call void @llvm.masked.scatter.v2f64.v2p0(<2 x double> %val, <2 x ptr> %ptrs, i32 8, <2 x i1> <i1 0, i1 poison>)
306+
ret void
307+
}
308+
309+
define void @scatter_one_withpoison_mask(<2 x ptr> %ptrs, <2 x double> %val) {
310+
; CHECK-LABEL: @scatter_one_withpoison_mask(
311+
; CHECK-NEXT: call void @llvm.masked.scatter.v2f64.v2p0(<2 x double> [[VAL:%.*]], <2 x ptr> [[PTRS:%.*]], i32 8, <2 x i1> <i1 true, i1 poison>)
312+
; CHECK-NEXT: ret void
313+
;
314+
call void @llvm.masked.scatter.v2f64.v2p0(<2 x double> %val, <2 x ptr> %ptrs, i32 8, <2 x i1> <i1 1, i1 poison>)
315+
ret void
316+
}
317+
260318
define void @scatter_demandedelts(ptr %ptr, double %val) {
261319
; CHECK-LABEL: @scatter_demandedelts(
262320
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr double, ptr [[PTR:%.*]], <2 x i64> <i64 0, i64 poison>

llvm/test/Transforms/InstCombine/pr83947.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ define void @masked_scatter2() {
2424

2525
define void @masked_scatter3() {
2626
; CHECK-LABEL: define void @masked_scatter3() {
27-
; CHECK-NEXT: store i32 0, ptr @c, align 4
2827
; CHECK-NEXT: ret void
2928
;
3029
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
@@ -50,7 +49,6 @@ define void @masked_scatter5() {
5049

5150
define void @masked_scatter6() {
5251
; CHECK-LABEL: define void @masked_scatter6() {
53-
; CHECK-NEXT: store i32 0, ptr @c, align 4
5452
; CHECK-NEXT: ret void
5553
;
5654
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)

0 commit comments

Comments
 (0)