Skip to content

Commit c174035

Browse files
committed
[SelectionDAG] Deal with POISON for INSERT_VECTOR_ELT/INSERT_SUBVECTOR
As reported in #141034, SelectionDAG::getNode had some unexpected behaviors when trying to create vectors with UNDEF elements. Since we treat both UNDEF and POISON as undefined (when using isUndef()), we can't just fold away INSERT_VECTOR_ELT/INSERT_SUBVECTOR based on isUndef(), as that could make the resulting vector more poisonous. The same kind of bug existed in DAGCombiner::visitINSERT_SUBVECTOR.

Here are some examples:

This fold was done even if vec[idx] was POISON:
  INSERT_VECTOR_ELT vec, UNDEF, idx -> vec

This fold was done even if any of vec[idx..idx+size] was POISON:
  INSERT_SUBVECTOR vec, UNDEF, idx -> vec

This fold was done even if the elements not extracted from vec could be POISON:
  sub = EXTRACT_SUBVECTOR vec, idx
  INSERT_SUBVECTOR UNDEF, sub, idx -> vec

With this patch we avoid such folds unless we can prove that the result isn't more poisonous when eliminating the insert.

Fixes #141034
1 parent ddcd3fd commit c174035

File tree

10 files changed

+284
-115
lines changed

10 files changed

+284
-115
lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,6 +1885,12 @@ LLVM_ABI SDValue peekThroughOneUseBitcasts(SDValue V);
18851885
/// If \p V is not an extracted subvector, it is returned as-is.
18861886
LLVM_ABI SDValue peekThroughExtractSubvectors(SDValue V);
18871887

1888+
/// Recursively peek through INSERT_VECTOR_ELT nodes, returning the source
1889+
/// vector operand of \p V, as long as \p V is an INSERT_VECTOR_ELT operation
1890+
/// that does not insert into any of the demanded vector elts.
1891+
LLVM_ABI SDValue peekThroughInsertVectorElt(SDValue V,
1892+
const APInt &DemandedElts);
1893+
18881894
/// Return the non-truncated source operand of \p V if it exists.
18891895
/// If \p V is not a truncation, it is returned as-is.
18901896
LLVM_ABI SDValue peekThroughTruncates(SDValue V);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23298,6 +23298,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2329823298
auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
2329923299

2330023300
// Insert into out-of-bounds element is undefined.
23301+
// The code below relies on this special case being handled early.
2330123302
if (IndexC && VT.isFixedLengthVector() &&
2330223303
IndexC->getZExtValue() >= VT.getVectorNumElements())
2330323304
return DAG.getUNDEF(VT);
@@ -23308,14 +23309,28 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2330823309
InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
2330923310
return InVec;
2331023311

23311-
if (!IndexC) {
23312-
// If this is variable insert to undef vector, it might be better to splat:
23313-
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23314-
if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23315-
return DAG.getSplat(VT, DL, InVal);
23316-
return SDValue();
23312+
// If this is variable insert to undef vector, it might be better to splat:
23313+
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23314+
if (!IndexC && InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23315+
return DAG.getSplat(VT, DL, InVal);
23316+
23317+
// Try to drop insert of UNDEF/POISON elements. This is also done in getNode,
23318+
// but we also do it as a DAG combine since for example simplifications into
23319+
// SPLAT_VECTOR/BUILD_VECTOR may turn poison elements into undef/zero etc, and
23320+
// then suddenly the InVec is guaranteed to not be poison.
23321+
if (InVal.isUndef()) {
23322+
if (IndexC && VT.isFixedLengthVector()) {
23323+
APInt EltMask = APInt::getOneBitSet(VT.getVectorNumElements(),
23324+
IndexC->getZExtValue());
23325+
if (DAG.isGuaranteedNotToBePoison(InVec, EltMask))
23326+
return InVec;
23327+
}
23328+
return DAG.getFreeze(InVec);
2331723329
}
2331823330

23331+
if (!IndexC)
23332+
return SDValue();
23333+
2331923334
if (VT.isScalableVector())
2332023335
return SDValue();
2332123336

@@ -27799,18 +27814,42 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
2779927814
SDValue N2 = N->getOperand(2);
2780027815
uint64_t InsIdx = N->getConstantOperandVal(2);
2780127816

27802-
// If inserting an UNDEF, just return the original vector.
27803-
if (N1.isUndef())
27804-
return N0;
27817+
// If inserting an UNDEF, just return the original vector (unless it makes the
27818+
// result more poisonous).
27819+
if (N1.isUndef()) {
27820+
if (N1.getOpcode() == ISD::POISON)
27821+
return N0;
27822+
if (VT.isFixedLengthVector()) {
27823+
unsigned SubVecNumElts = N1.getValueType().getVectorNumElements();
27824+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27825+
InsIdx + SubVecNumElts);
27826+
if (DAG.isGuaranteedNotToBePoison(N0, EltMask))
27827+
return N0;
27828+
}
27829+
return DAG.getFreeze(N0);
27830+
}
2780527831

27806-
// If this is an insert of an extracted vector into an undef vector, we can
27807-
// just use the input to the extract if the types match, and can simplify
27832+
// If this is an insert of an extracted vector into an undef/poison vector, we
27833+
// can just use the input to the extract if the types match, and can simplify
2780827834
// in some cases even if they don't.
2780927835
if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
2781027836
N1.getOperand(1) == N2) {
27837+
EVT N1VT = N1.getValueType();
2781127838
EVT SrcVT = N1.getOperand(0).getValueType();
27812-
if (SrcVT == VT)
27813-
return N1.getOperand(0);
27839+
if (SrcVT == VT) {
27840+
// Need to ensure that result isn't more poisonous if skipping both the
27841+
// extract+insert.
27842+
if (N0.getOpcode() == ISD::POISON)
27843+
return N1.getOperand(0);
27844+
if (VT.isFixedLengthVector() && N1VT.isFixedLengthVector()) {
27845+
unsigned SubVecNumElts = N1VT.getVectorNumElements();
27846+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27847+
InsIdx + SubVecNumElts);
27848+
if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0), ~EltMask))
27849+
return N1.getOperand(0);
27850+
} else if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0)))
27851+
return N1.getOperand(0);
27852+
}
2781427853
// TODO: To remove the zero check, need to adjust the offset to
2781527854
// a multiple of the new src type.
2781627855
if (isNullConstant(N2)) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5513,8 +5513,9 @@ bool SelectionDAG::isGuaranteedNotToBeUndefOrPoison(SDValue Op,
55135513
APInt InVecDemandedElts = DemandedElts;
55145514
InVecDemandedElts.clearBit(IndexC->getZExtValue());
55155515
if (!!InVecDemandedElts &&
5516-
!isGuaranteedNotToBeUndefOrPoison(InVec, InVecDemandedElts,
5517-
PoisonOnly, Depth + 1))
5516+
!isGuaranteedNotToBeUndefOrPoison(
5517+
peekThroughInsertVectorElt(InVec, InVecDemandedElts),
5518+
InVecDemandedElts, PoisonOnly, Depth + 1))
55185519
return false;
55195520
return true;
55205521
}
@@ -8215,23 +8216,42 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
82158216
// INSERT_VECTOR_ELT into out-of-bounds element is an UNDEF, except
82168217
// for scalable vectors where we will generate appropriate code to
82178218
// deal with out-of-bounds cases correctly.
8218-
if (N3C && N1.getValueType().isFixedLengthVector() &&
8219-
N3C->getZExtValue() >= N1.getValueType().getVectorNumElements())
8219+
if (N3C && VT.isFixedLengthVector() &&
8220+
N3C->getZExtValue() >= VT.getVectorNumElements())
82208221
return getUNDEF(VT);
82218222

82228223
// Undefined index can be assumed out-of-bounds, so that's UNDEF too.
82238224
if (N3.isUndef())
82248225
return getUNDEF(VT);
82258226

8226-
// If the inserted element is an UNDEF, just use the input vector.
8227-
if (N2.isUndef())
8227+
// If inserting poison, just use the input vector.
8228+
if (N2.getOpcode() == ISD::POISON)
82288229
return N1;
82298230

8231+
// Inserting undef into undef/poison is still undef.
8232+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
8233+
return getUNDEF(VT);
8234+
8235+
// If the inserted element is an UNDEF, just use the input vector.
8236+
// But not if skipping the insert could make the result more poisonous.
8237+
if (N2.isUndef()) {
8238+
if (N3C && VT.isFixedLengthVector()) {
8239+
APInt EltMask =
8240+
APInt::getOneBitSet(VT.getVectorNumElements(), N3C->getZExtValue());
8241+
if (isGuaranteedNotToBePoison(N1, EltMask))
8242+
return N1;
8243+
} else if (isGuaranteedNotToBePoison(N1))
8244+
return N1;
8245+
}
82308246
break;
82318247
}
82328248
case ISD::INSERT_SUBVECTOR: {
8233-
// Inserting undef into undef is still undef.
8234-
if (N1.isUndef() && N2.isUndef())
8249+
// If inserting poison, just use the input vector.
8250+
if (N2.getOpcode() == ISD::POISON)
8251+
return N1;
8252+
8253+
// Inserting undef into undef/poison is still undef.
8254+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
82358255
return getUNDEF(VT);
82368256

82378257
EVT N2VT = N2.getValueType();
@@ -8260,11 +8280,37 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
82608280
if (VT == N2VT)
82618281
return N2;
82628282

8263-
// If this is an insert of an extracted vector into an undef vector, we
8264-
// can just use the input to the extract.
8283+
// If this is an insert of an extracted vector into an undef/poison vector,
8284+
// we can just use the input to the extract. But not if skipping the
8285+
// extract+insert could make the result more poisonous.
82658286
if (N1.isUndef() && N2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
8266-
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT)
8267-
return N2.getOperand(0);
8287+
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT) {
8288+
if (N1.getOpcode() == ISD::POISON)
8289+
return N2.getOperand(0);
8290+
if (VT.isFixedLengthVector() && N2VT.isFixedLengthVector()) {
8291+
unsigned LoBit = N3->getAsZExtVal();
8292+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
8293+
APInt EltMask =
8294+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
8295+
if (isGuaranteedNotToBePoison(N2.getOperand(0), ~EltMask))
8296+
return N2.getOperand(0);
8297+
} else if (isGuaranteedNotToBePoison(N2.getOperand(0)))
8298+
return N2.getOperand(0);
8299+
}
8300+
8301+
// If the inserted subvector is UNDEF, just use the input vector.
8302+
// But not if skipping the insert could make the result more poisonous.
8303+
if (N2.isUndef()) {
8304+
if (VT.isFixedLengthVector()) {
8305+
unsigned LoBit = N3->getAsZExtVal();
8306+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
8307+
APInt EltMask =
8308+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
8309+
if (isGuaranteedNotToBePoison(N1, EltMask))
8310+
return N1;
8311+
} else if (isGuaranteedNotToBePoison(N1))
8312+
return N1;
8313+
}
82688314
break;
82698315
}
82708316
case ISD::BITCAST:
@@ -12729,6 +12775,23 @@ SDValue llvm::peekThroughExtractSubvectors(SDValue V) {
1272912775
return V;
1273012776
}
1273112777

12778+
SDValue llvm::peekThroughInsertVectorElt(SDValue V, const APInt &DemandedElts) {
12779+
while (V.getOpcode() == ISD::INSERT_VECTOR_ELT) {
12780+
SDValue InVec = V.getOperand(0);
12781+
SDValue EltNo = V.getOperand(2);
12782+
EVT VT = InVec.getValueType();
12783+
auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
12784+
if (IndexC && VT.isFixedLengthVector() &&
12785+
IndexC->getAPIntValue().ult(VT.getVectorNumElements()) &&
12786+
!DemandedElts[IndexC->getZExtValue()]) {
12787+
V = InVec;
12788+
continue;
12789+
}
12790+
break;
12791+
}
12792+
return V;
12793+
}
12794+
1273212795
SDValue llvm::peekThroughTruncates(SDValue V) {
1273312796
while (V.getOpcode() == ISD::TRUNCATE)
1273412797
V = V.getOperand(0);

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3446,8 +3446,8 @@ bool TargetLowering::SimplifyDemandedVectorElts(
34463446
break;
34473447
}
34483448
case ISD::INSERT_SUBVECTOR: {
3449-
// Demand any elements from the subvector and the remainder from the src its
3450-
// inserted into.
3449+
// Demand any elements from the subvector and the remainder from the src it
3450+
// is inserted into.
34513451
SDValue Src = Op.getOperand(0);
34523452
SDValue Sub = Op.getOperand(1);
34533453
uint64_t Idx = Op.getConstantOperandVal(2);
@@ -3456,6 +3456,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
34563456
APInt DemandedSrcElts = DemandedElts;
34573457
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
34583458

3459+
// If none of the sub operand elements are demanded, bypass the insert.
3460+
if (!DemandedSubElts)
3461+
return TLO.CombineTo(Op, Src);
3462+
34593463
APInt SubUndef, SubZero;
34603464
if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
34613465
Depth + 1))

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15357,7 +15357,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
1535715357
for (unsigned i = 0; i < NumElts; ++i) {
1535815358
SDValue V = Op.getOperand(i);
1535915359
SDValue LaneIdx = DAG.getConstant(i, DL, MVT::i64);
15360-
if (!isIntOrFPConstant(V))
15360+
if (!isIntOrFPConstant(V) && !V.isUndef())
1536115361
// Note that type legalization likely mucked about with the VT of the
1536215362
// source operand, so we may have to convert it here before inserting.
1536315363
Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Val, V, LaneIdx);

llvm/test/CodeGen/AArch64/concat-vector-add-combine.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,14 @@ define i32 @combine_undef_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d) local_unname
9494
; CHECK-LABEL: combine_undef_add_8xi32:
9595
; CHECK: // %bb.0:
9696
; CHECK-NEXT: fmov s1, w0
97-
; CHECK-NEXT: movi v0.2d, #0000000000000000
97+
; CHECK-NEXT: dup v0.4s, w8
9898
; CHECK-NEXT: mov v1.s[1], w1
99-
; CHECK-NEXT: uhadd v0.4h, v0.4h, v0.4h
10099
; CHECK-NEXT: mov v1.s[2], w2
101100
; CHECK-NEXT: mov v1.s[3], w3
102-
; CHECK-NEXT: xtn v2.4h, v1.4s
103-
; CHECK-NEXT: shrn v1.4h, v1.4s, #16
104-
; CHECK-NEXT: uhadd v1.4h, v2.4h, v1.4h
105-
; CHECK-NEXT: mov v1.d[1], v0.d[0]
106-
; CHECK-NEXT: uaddlv s0, v1.8h
101+
; CHECK-NEXT: uzp2 v2.8h, v1.8h, v0.8h
102+
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
103+
; CHECK-NEXT: uhadd v0.8h, v0.8h, v2.8h
104+
; CHECK-NEXT: uaddlv s0, v0.8h
107105
; CHECK-NEXT: fmov w0, s0
108106
; CHECK-NEXT: ret
109107
%a1 = insertelement <8 x i32> poison, i32 %a, i32 0

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,11 +1198,15 @@ define void @masked_gather_passthru(ptr %a, ptr %b, ptr %c) vscale_range(16,0) #
11981198
; CHECK-NEXT: ptrue p0.s, vl32
11991199
; CHECK-NEXT: ptrue p2.d, vl32
12001200
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
1201-
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x2]
12021201
; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, #0.0
12031202
; CHECK-NEXT: ld1d { z0.d }, p2/z, [x1]
12041203
; CHECK-NEXT: punpklo p2.h, p1.b
1204+
; CHECK-NEXT: mov z1.s, p1/z, #-1 // =0xffffffffffffffff
1205+
; CHECK-NEXT: ptrue p1.s
12051206
; CHECK-NEXT: ld1w { z0.d }, p2/z, [z0.d]
1207+
; CHECK-NEXT: and z1.s, z1.s, #0x1
1208+
; CHECK-NEXT: cmpne p1.s, p1/z, z1.s, #0
1209+
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x2]
12061210
; CHECK-NEXT: uzp1 z0.s, z0.s, z0.s
12071211
; CHECK-NEXT: sel z0.s, p1, z0.s, z1.s
12081212
; CHECK-NEXT: st1w { z0.s }, p0, [x0]

0 commit comments

Comments
 (0)