Skip to content

Commit 6d6c250

Browse files
committed
fix: modify review for code details
1 parent f23969c commit 6d6c250

File tree

5 files changed

+28
-22
lines changed

5 files changed

+28
-22
lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,8 +1937,14 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
19371937
/// Does not permit build vector implicit truncation.
19381938
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
19391939

1940+
/// Return true if the value is a constant 1 integer or a splatted vector of a
1941+
/// constant 1 integer (with no undefs).
1942+
/// Does not permit build vector implicit truncation.
19401943
LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
19411944

1945+
/// Return true if the value is a constant 0 integer or a splatted vector of a
1946+
/// constant 0 integer (with no undefs).
1947+
/// Does not permit build vector implicit truncation.
19421948
LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
19431949

19441950
/// Return true if \p V is either a integer or FP constant.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4281,7 +4281,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
42814281
return V;
42824282

42834283
// (A - B) - 1 -> add (xor B, -1), A
4284-
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
4284+
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))),
4285+
m_One(/*AllowUndefs=*/true))))
42854286
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
42864287

42874288
// Look for:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12570,9 +12570,9 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
1257012570
}
1257112571

1257212572
bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
12573-
N = peekThroughBitcasts(N);
1257412573
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
12575-
return C && C->getAPIntValue() == 1;
12574+
return C && APInt::isSameValue(C->getAPIntValue(),
12575+
APInt(C->getAPIntValue().getBitWidth(), 1));
1257612576
}
1257712577

1257812578
bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57925,20 +57925,22 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5792557925
}
5792657926
}
5792757927

57928-
SDValue X, Y;
57929-
5793057928
// add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
5793157929
// iff X and Y won't overflow.
57932-
if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
57933-
sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
57934-
DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
57935-
MVT OpVT = X.getSimpleValueType();
57936-
SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
57937-
return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
57938-
getZeroVector(OpVT, Subtarget, DAG, DL));
57930+
if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
57931+
ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
57932+
ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
57933+
if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
57934+
MVT OpVT = Op0.getOperand(1).getSimpleValueType();
57935+
SDValue Sum =
57936+
DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
57937+
return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
57938+
getZeroVector(OpVT, Subtarget, DAG, DL));
57939+
}
5793957940
}
5794057941

5794157942
if (VT.isVector()) {
57943+
SDValue X, Y;
5794257944
EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
5794357945
VT.getVectorElementCount());
5794457946

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) {
867867
SDValue SplatVal = DAG->getConstant(1, DL, MVT::i32);
868868
SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal);
869869
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat);
870-
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One()));
870+
EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One()));
871871
}
872872

873873
// m_AllOnes: splat vector of -1 → bitcast
@@ -887,9 +887,8 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) {
887887
SmallVector<SDValue, 4> Ops(4, Zero);
888888
Ops[2] = Undef;
889889
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
890-
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec);
891-
EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_Zero()));
892-
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_Zero(true)));
890+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_Zero()));
891+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_Zero(true)));
893892
}
894893

895894
{
@@ -898,9 +897,8 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) {
898897
SmallVector<SDValue, 4> Ops(4, One);
899898
Ops[1] = Undef;
900899
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
901-
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec);
902-
EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One()));
903-
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One(true)));
900+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_One()));
901+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_One(true)));
904902
}
905903

906904
{
@@ -909,8 +907,7 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) {
909907
SmallVector<SDValue, 4> Ops(4, AllOnes);
910908
Ops[0] = Undef;
911909
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
912-
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec);
913-
EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_AllOnes()));
914-
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_AllOnes(true)));
910+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_AllOnes()));
911+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_AllOnes(true)));
915912
}
916913
}

0 commit comments

Comments
 (0)