Skip to content

Commit 47811e8

Browse files
committed
[SelectionDAG] Deal with POISON for INSERT_VECTOR_ELT/INSERT_SUBVECTOR (part 1)
As reported in #141034, SelectionDAG::getNode had some unexpected behaviors when trying to create vectors with UNDEF elements. Since we treat both UNDEF and POISON as undefined (when using isUndef()), we can't just fold away INSERT_VECTOR_ELT/INSERT_SUBVECTOR based on isUndef(), as that could make the resulting vector more poisonous.

The same kind of bug existed in DAGCombiner::visitINSERT_SUBVECTOR.

Here are some examples:

This fold was done even if vec[idx] was POISON:
  INSERT_VECTOR_ELT vec, UNDEF, idx -> vec

This fold was done even if any of vec[idx..idx+size] was POISON:
  INSERT_SUBVECTOR vec, UNDEF, idx -> vec

This fold was done even if the elements not extracted from vec could be POISON:
  sub = EXTRACT_SUBVECTOR vec, idx
  INSERT_SUBVECTOR UNDEF, sub, idx -> vec

With this patch we avoid such folds unless we can prove that the result isn't more poisonous when eliminating the insert.

Fixes #141034
1 parent d15b7a8 commit 47811e8

File tree

10 files changed

+285
-116
lines changed

10 files changed

+285
-116
lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,6 +1885,12 @@ LLVM_ABI SDValue peekThroughOneUseBitcasts(SDValue V);
18851885
/// If \p V is not an extracted subvector, it is returned as-is.
18861886
LLVM_ABI SDValue peekThroughExtractSubvectors(SDValue V);
18871887

1888+
/// Recursively peek through INSERT_VECTOR_ELT nodes, returning the source
1889+
/// vector operand of \p V, as long as \p V is an INSERT_VECTOR_ELT operation
1890+
/// that does not insert into any of the demanded vector elts.
1891+
LLVM_ABI SDValue peekThroughInsertVectorElt(SDValue V,
1892+
const APInt &DemandedElts);
1893+
18881894
/// Return the non-truncated source operand of \p V if it exists.
18891895
/// If \p V is not a truncation, it is returned as-is.
18901896
LLVM_ABI SDValue peekThroughTruncates(SDValue V);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23281,6 +23281,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2328123281
auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
2328223282

2328323283
// Insert into out-of-bounds element is undefined.
23284+
// Code below relies on this special case being handled early.
2328423285
if (IndexC && VT.isFixedLengthVector() &&
2328523286
IndexC->getZExtValue() >= VT.getVectorNumElements())
2328623287
return DAG.getUNDEF(VT);
@@ -23291,14 +23292,28 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2329123292
InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
2329223293
return InVec;
2329323294

23294-
if (!IndexC) {
23295-
// If this is variable insert to undef vector, it might be better to splat:
23296-
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23297-
if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23298-
return DAG.getSplat(VT, DL, InVal);
23299-
return SDValue();
23295+
// If this is variable insert to undef vector, it might be better to splat:
23296+
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23297+
if (!IndexC && InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23298+
return DAG.getSplat(VT, DL, InVal);
23299+
23300+
// Try to drop insert of UNDEF/POISON elements. This is also done in getNode,
23301+
// but we also do it as a DAG combine since for example simplifications into
23302+
// SPLAT_VECTOR/BUILD_VECTOR may turn poison elements into undef/zero etc, and
23303+
// then suddenly the InVec is guaranteed to not be poison.
23304+
if (InVal.isUndef()) {
23305+
if (IndexC && VT.isFixedLengthVector()) {
23306+
APInt EltMask = APInt::getOneBitSet(VT.getVectorNumElements(),
23307+
IndexC->getZExtValue());
23308+
if (DAG.isGuaranteedNotToBePoison(InVec, EltMask))
23309+
return InVec;
23310+
}
23311+
return DAG.getFreeze(InVec);
2330023312
}
2330123313

23314+
if (!IndexC)
23315+
return SDValue();
23316+
2330223317
if (VT.isScalableVector())
2330323318
return SDValue();
2330423319

@@ -27779,18 +27794,42 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
2777927794
SDValue N2 = N->getOperand(2);
2778027795
uint64_t InsIdx = N->getConstantOperandVal(2);
2778127796

27782-
// If inserting an UNDEF, just return the original vector.
27783-
if (N1.isUndef())
27784-
return N0;
27797+
// If inserting an UNDEF, just return the original vector (unless it makes the
27798+
// result more poisonous).
27799+
if (N1.isUndef()) {
27800+
if (N1.getOpcode() == ISD::POISON)
27801+
return N0;
27802+
if (VT.isFixedLengthVector()) {
27803+
unsigned SubVecNumElts = N1.getValueType().getVectorNumElements();
27804+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27805+
InsIdx + SubVecNumElts);
27806+
if (DAG.isGuaranteedNotToBePoison(N0, EltMask))
27807+
return N0;
27808+
}
27809+
return DAG.getFreeze(N0);
27810+
}
2778527811

27786-
// If this is an insert of an extracted vector into an undef vector, we can
27787-
// just use the input to the extract if the types match, and can simplify
27812+
// If this is an insert of an extracted vector into an undef/poison vector, we
27813+
// can just use the input to the extract if the types match, and can simplify
2778827814
// in some cases even if they don't.
2778927815
if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
2779027816
N1.getOperand(1) == N2) {
27817+
EVT N1VT = N1.getValueType();
2779127818
EVT SrcVT = N1.getOperand(0).getValueType();
27792-
if (SrcVT == VT)
27793-
return N1.getOperand(0);
27819+
if (SrcVT == VT) {
27820+
// Need to ensure that result isn't more poisonous if skipping both the
27821+
// extract+insert.
27822+
if (N0.getOpcode() == ISD::POISON)
27823+
return N1.getOperand(0);
27824+
if (VT.isFixedLengthVector() && N1VT.isFixedLengthVector()) {
27825+
unsigned SubVecNumElts = N1VT.getVectorNumElements();
27826+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27827+
InsIdx + SubVecNumElts);
27828+
if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0), ~EltMask))
27829+
return N1.getOperand(0);
27830+
} else if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0)))
27831+
return N1.getOperand(0);
27832+
}
2779427833
// TODO: To remove the zero check, need to adjust the offset to
2779527834
// a multiple of the new src type.
2779627835
if (isNullConstant(N2)) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5519,8 +5519,9 @@ bool SelectionDAG::isGuaranteedNotToBeUndefOrPoison(SDValue Op,
55195519
APInt InVecDemandedElts = DemandedElts;
55205520
InVecDemandedElts.clearBit(IndexC->getZExtValue());
55215521
if (!!InVecDemandedElts &&
5522-
!isGuaranteedNotToBeUndefOrPoison(InVec, InVecDemandedElts,
5523-
PoisonOnly, Depth + 1))
5522+
!isGuaranteedNotToBeUndefOrPoison(
5523+
peekThroughInsertVectorElt(InVec, InVecDemandedElts),
5524+
InVecDemandedElts, PoisonOnly, Depth + 1))
55245525
return false;
55255526
return true;
55265527
}
@@ -8219,23 +8220,42 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
82198220
// INSERT_VECTOR_ELT into out-of-bounds element is an UNDEF, except
82208221
// for scalable vectors where we will generate appropriate code to
82218222
// deal with out-of-bounds cases correctly.
8222-
if (N3C && N1.getValueType().isFixedLengthVector() &&
8223-
N3C->getZExtValue() >= N1.getValueType().getVectorNumElements())
8223+
if (N3C && VT.isFixedLengthVector() &&
8224+
N3C->getZExtValue() >= VT.getVectorNumElements())
82248225
return getUNDEF(VT);
82258226

82268227
// Undefined index can be assumed out-of-bounds, so that's UNDEF too.
82278228
if (N3.isUndef())
82288229
return getUNDEF(VT);
82298230

8230-
// If the inserted element is an UNDEF, just use the input vector.
8231-
if (N2.isUndef())
8231+
// If inserting poison, just use the input vector.
8232+
if (N2.getOpcode() == ISD::POISON)
82328233
return N1;
82338234

8235+
// Inserting undef into undef/poison is still undef.
8236+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
8237+
return getUNDEF(VT);
8238+
8239+
// If the inserted element is an UNDEF, just use the input vector.
8240+
// But not if skipping the insert could make the result more poisonous.
8241+
if (N2.isUndef()) {
8242+
if (N3C && VT.isFixedLengthVector()) {
8243+
APInt EltMask =
8244+
APInt::getOneBitSet(VT.getVectorNumElements(), N3C->getZExtValue());
8245+
if (isGuaranteedNotToBePoison(N1, EltMask))
8246+
return N1;
8247+
} else if (isGuaranteedNotToBePoison(N1))
8248+
return N1;
8249+
}
82348250
break;
82358251
}
82368252
case ISD::INSERT_SUBVECTOR: {
8237-
// Inserting undef into undef is still undef.
8238-
if (N1.isUndef() && N2.isUndef())
8253+
// If inserting poison, just use the input vector.
8254+
if (N2.getOpcode() == ISD::POISON)
8255+
return N1;
8256+
8257+
// Inserting undef into undef/poison is still undef.
8258+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
82398259
return getUNDEF(VT);
82408260

82418261
EVT N2VT = N2.getValueType();
@@ -8264,11 +8284,37 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
82648284
if (VT == N2VT)
82658285
return N2;
82668286

8267-
// If this is an insert of an extracted vector into an undef vector, we
8268-
// can just use the input to the extract.
8287+
// If this is an insert of an extracted vector into an undef/poison vector,
8288+
// we can just use the input to the extract. But not if skipping the
8289+
// extract+insert could make the result more poisonous.
82698290
if (N1.isUndef() && N2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
8270-
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT)
8271-
return N2.getOperand(0);
8291+
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT) {
8292+
if (N1.getOpcode() == ISD::POISON)
8293+
return N2.getOperand(0);
8294+
if (VT.isFixedLengthVector() && N2VT.isFixedLengthVector()) {
8295+
unsigned LoBit = N3->getAsZExtVal();
8296+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
8297+
APInt EltMask =
8298+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
8299+
if (isGuaranteedNotToBePoison(N2.getOperand(0), ~EltMask))
8300+
return N2.getOperand(0);
8301+
} else if (isGuaranteedNotToBePoison(N2.getOperand(0)))
8302+
return N2.getOperand(0);
8303+
}
8304+
8305+
// If the inserted subvector is UNDEF, just use the input vector.
8306+
// But not if skipping the insert could make the result more poisonous.
8307+
if (N2.isUndef()) {
8308+
if (VT.isFixedLengthVector()) {
8309+
unsigned LoBit = N3->getAsZExtVal();
8310+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
8311+
APInt EltMask =
8312+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
8313+
if (isGuaranteedNotToBePoison(N1, EltMask))
8314+
return N1;
8315+
} else if (isGuaranteedNotToBePoison(N1))
8316+
return N1;
8317+
}
82728318
break;
82738319
}
82748320
case ISD::BITCAST:
@@ -12777,6 +12823,23 @@ SDValue llvm::peekThroughExtractSubvectors(SDValue V) {
1277712823
return V;
1277812824
}
1277912825

12826+
SDValue llvm::peekThroughInsertVectorElt(SDValue V, const APInt &DemandedElts) {
12827+
while (V.getOpcode() == ISD::INSERT_VECTOR_ELT) {
12828+
SDValue InVec = V.getOperand(0);
12829+
SDValue EltNo = V.getOperand(2);
12830+
EVT VT = InVec.getValueType();
12831+
auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
12832+
if (IndexC && VT.isFixedLengthVector() &&
12833+
IndexC->getAPIntValue().ult(VT.getVectorNumElements()) &&
12834+
!DemandedElts[IndexC->getZExtValue()]) {
12835+
V = InVec;
12836+
continue;
12837+
}
12838+
break;
12839+
}
12840+
return V;
12841+
}
12842+
1278012843
SDValue llvm::peekThroughTruncates(SDValue V) {
1278112844
while (V.getOpcode() == ISD::TRUNCATE)
1278212845
V = V.getOperand(0);

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3440,8 +3440,8 @@ bool TargetLowering::SimplifyDemandedVectorElts(
34403440
break;
34413441
}
34423442
case ISD::INSERT_SUBVECTOR: {
3443-
// Demand any elements from the subvector and the remainder from the src its
3444-
// inserted into.
3443+
// Demand any elements from the subvector and the remainder from the src it
3444+
// is inserted into.
34453445
SDValue Src = Op.getOperand(0);
34463446
SDValue Sub = Op.getOperand(1);
34473447
uint64_t Idx = Op.getConstantOperandVal(2);
@@ -3450,6 +3450,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
34503450
APInt DemandedSrcElts = DemandedElts;
34513451
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
34523452

3453+
// If none of the sub operand elements are demanded, bypass the insert.
3454+
if (!DemandedSubElts)
3455+
return TLO.CombineTo(Op, Src);
3456+
34533457
APInt SubUndef, SubZero;
34543458
if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
34553459
Depth + 1))

llvm/test/CodeGen/AArch64/arm64-build-vector.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ define void @widen_f16_build_vector(ptr %addr) {
5757
; CHECK-LABEL: widen_f16_build_vector:
5858
; CHECK: // %bb.0:
5959
; CHECK-NEXT: mov w8, #13294 // =0x33ee
60-
; CHECK-NEXT: movk w8, #13294, lsl #16
61-
; CHECK-NEXT: str w8, [x0]
60+
; CHECK-NEXT: dup v0.4h, w8
61+
; CHECK-NEXT: str s0, [x0]
6262
; CHECK-NEXT: ret
6363
store <2 x half> <half 0xH33EE, half 0xH33EE>, ptr %addr, align 2
6464
ret void

llvm/test/CodeGen/AArch64/concat-vector-add-combine.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,14 @@ define i32 @combine_undef_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d) local_unname
9494
; CHECK-LABEL: combine_undef_add_8xi32:
9595
; CHECK: // %bb.0:
9696
; CHECK-NEXT: fmov s1, w0
97-
; CHECK-NEXT: movi v0.2d, #0000000000000000
97+
; CHECK-NEXT: dup v0.4s, w8
9898
; CHECK-NEXT: mov v1.s[1], w1
99-
; CHECK-NEXT: uhadd v0.4h, v0.4h, v0.4h
10099
; CHECK-NEXT: mov v1.s[2], w2
101100
; CHECK-NEXT: mov v1.s[3], w3
102-
; CHECK-NEXT: xtn v2.4h, v1.4s
103-
; CHECK-NEXT: shrn v1.4h, v1.4s, #16
104-
; CHECK-NEXT: uhadd v1.4h, v2.4h, v1.4h
105-
; CHECK-NEXT: mov v1.d[1], v0.d[0]
106-
; CHECK-NEXT: uaddlv s0, v1.8h
101+
; CHECK-NEXT: uzp2 v2.8h, v1.8h, v0.8h
102+
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
103+
; CHECK-NEXT: uhadd v0.8h, v0.8h, v2.8h
104+
; CHECK-NEXT: uaddlv s0, v0.8h
107105
; CHECK-NEXT: fmov w0, s0
108106
; CHECK-NEXT: ret
109107
%a1 = insertelement <8 x i32> poison, i32 %a, i32 0

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,11 +1198,15 @@ define void @masked_gather_passthru(ptr %a, ptr %b, ptr %c) vscale_range(16,0) #
11981198
; CHECK-NEXT: ptrue p0.s, vl32
11991199
; CHECK-NEXT: ptrue p2.d, vl32
12001200
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
1201-
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x2]
12021201
; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, #0.0
12031202
; CHECK-NEXT: ld1d { z0.d }, p2/z, [x1]
12041203
; CHECK-NEXT: punpklo p2.h, p1.b
1204+
; CHECK-NEXT: mov z1.s, p1/z, #-1 // =0xffffffffffffffff
1205+
; CHECK-NEXT: ptrue p1.s
12051206
; CHECK-NEXT: ld1w { z0.d }, p2/z, [z0.d]
1207+
; CHECK-NEXT: and z1.s, z1.s, #0x1
1208+
; CHECK-NEXT: cmpne p1.s, p1/z, z1.s, #0
1209+
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x2]
12061210
; CHECK-NEXT: uzp1 z0.s, z0.s, z0.s
12071211
; CHECK-NEXT: sel z0.s, p1, z0.s, z1.s
12081212
; CHECK-NEXT: st1w { z0.s }, p0, [x0]

0 commit comments

Comments
 (0)