From 0e45daae938a2e20832cb8169473da53f629d7ff Mon Sep 17 00:00:00 2001 From: Piotr Fusik Date: Mon, 17 Feb 2025 15:48:12 +0100 Subject: [PATCH 1/2] [SelectionDAG][NFC] Refactor duplicate code into SDNode::bitcastToAPInt() --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 10 +++++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 6 +-- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 38 ++++++------------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 6eff6bfe8d5b1..75c4fabe03dd4 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -989,6 +989,8 @@ END_TWO_BYTE_PACK() /// Helper method returns the APInt value of a ConstantSDNode. inline const APInt &getAsAPIntVal() const; + inline std::optional bitcastToAPInt() const; + const SDValue &getOperand(unsigned Num) const { assert(Num < NumOperands && "Invalid child # of SDNode!"); return OperandList[Num]; @@ -1785,6 +1787,14 @@ class ConstantFPSDNode : public SDNode { } }; +std::optional SDNode::bitcastToAPInt() const { + if (auto *CN = dyn_cast(this)) + return CN->getAPIntValue(); + if (auto *CFPN = dyn_cast(this)) + return CFPN->getValueAPF().bitcastToAPInt(); + return std::nullopt; +} + /// Returns true if \p V is a constant integer zero. bool isNullConstant(SDValue V); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c6fd72b6b76f4..59f5d15cfed30 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -27420,10 +27420,8 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { } APInt Bits; - if (auto *Cst = dyn_cast(Elt)) - Bits = Cst->getAPIntValue(); - else if (auto *CstFP = dyn_cast(Elt)) - Bits = CstFP->getValueAPF().bitcastToAPInt(); + if (auto OptBits = Elt->bitcastToAPInt()) + Bits = *std::move(OptBits); else return SDValue(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 9d2f87497d6fa..531314c5bfd07 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -152,14 +152,10 @@ bool ConstantFPSDNode::isValueValidForType(EVT VT, bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { if (N->getOpcode() == ISD::SPLAT_VECTOR) { - unsigned EltSize = - N->getValueType(0).getVectorElementType().getSizeInBits(); - if (auto *Op0 = dyn_cast(N->getOperand(0))) { - SplatVal = Op0->getAPIntValue().trunc(EltSize); - return true; - } - if (auto *Op0 = dyn_cast(N->getOperand(0))) { - SplatVal = Op0->getValueAPF().bitcastToAPInt().trunc(EltSize); + if (auto OptAPInt = N->getOperand(0)->bitcastToAPInt()) { + unsigned EltSize = + N->getValueType(0).getVectorElementType().getSizeInBits(); + SplatVal = OptAPInt->trunc(EltSize); return true; } } @@ -215,12 +211,9 @@ bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) { // we care if the resultant vector is all ones, not whether the individual // constants are. SDValue NotZero = N->getOperand(i); - unsigned EltSize = N->getValueType(0).getScalarSizeInBits(); - if (ConstantSDNode *CN = dyn_cast(NotZero)) { - if (CN->getAPIntValue().countr_one() < EltSize) - return false; - } else if (ConstantFPSDNode *CFPN = dyn_cast(NotZero)) { - if (CFPN->getValueAPF().bitcastToAPInt().countr_one() < EltSize) + if (auto OptAPInt = NotZero->bitcastToAPInt()) { + unsigned EltSize = N->getValueType(0).getScalarSizeInBits(); + if (OptAPInt->countr_one() < EltSize) return false; } else return false; @@ -259,12 +252,9 @@ bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) { // We only want to check enough bits to cover the vector elements, because // we care if the resultant vector is all zeros, not whether the individual // constants are. - unsigned EltSize = N->getValueType(0).getScalarSizeInBits(); - if (ConstantSDNode *CN = dyn_cast(Op)) { - if (CN->getAPIntValue().countr_zero() < EltSize) - return false; - } else if (ConstantFPSDNode *CFPN = dyn_cast(Op)) { - if (CFPN->getValueAPF().bitcastToAPInt().countr_zero() < EltSize) + if (auto OptAPInt = Op->bitcastToAPInt()) { + unsigned EltSize = N->getValueType(0).getScalarSizeInBits(); + if (OptAPInt->countr_zero() < EltSize) return false; } else return false; @@ -3434,13 +3424,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, KnownBits Known(BitWidth); // Don't know anything. - if (auto *C = dyn_cast(Op)) { + if (auto OptAPInt = Op->bitcastToAPInt()) { // We know all of the bits for a constant! - return KnownBits::makeConstant(C->getAPIntValue()); - } - if (auto *C = dyn_cast(Op)) { - // We know all of the bits for a constant fp! - return KnownBits::makeConstant(C->getValueAPF().bitcastToAPInt()); + return KnownBits::makeConstant(*std::move(OptAPInt)); } if (Depth >= MaxRecursionDepth) From 4b3731cc2e48fa9b28fc3ec8639372dfea0ea704 Mon Sep 17 00:00:00 2001 From: Piotr Fusik Date: Wed, 19 Feb 2025 11:42:30 +0100 Subject: [PATCH 2/2] [RISCV] Apply a review suggestion --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 59f5d15cfed30..66d15fcee775c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -27419,21 +27419,20 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { continue; } - APInt Bits; - if (auto OptBits = Elt->bitcastToAPInt()) - Bits = *std::move(OptBits); - else + std::optional Bits = Elt->bitcastToAPInt(); + if (!Bits) return SDValue(); // Extract the sub element from the constant bit mask. if (DAG.getDataLayout().isBigEndian()) - Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits); + *Bits = + Bits->extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits); else - Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits); + *Bits = Bits->extractBits(NumSubBits, SubIdx * NumSubBits); - if (Bits.isAllOnes()) + if (Bits->isAllOnes()) Indices.push_back(i); - else if (Bits == 0) + else if (*Bits == 0) Indices.push_back(i + NumSubElts); else return SDValue();