Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,11 @@ m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
}

template <typename T0_P, typename T1_P, typename T2_P>
inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
}

template <typename T0_P, typename T1_P, typename T2_P>
inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
Expand Down
41 changes: 21 additions & 20 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2476,28 +2476,28 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
/// masked vector operation if the target supports it.
static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
bool ShouldCommuteOperands) {
// Match a select as operand 1. The identity constant that we are looking for
// is only valid as operand 1 of a non-commutative binop.
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);

// Match a select as operand 1. The identity constant that we are looking for
// is only valid as operand 1 of a non-commutative binop.
if (ShouldCommuteOperands)
std::swap(N0, N1);

unsigned SelOpcode = N1.getOpcode();
if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
!N1.hasOneUse())
SDValue Cond, TVal, FVal;
if (!sd_match(N1, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
m_Value(FVal))))) {
return SDValue();
}

// We can't hoist all instructions because of immediate UB (not speculatable).
// For example div/rem by zero.
if (!DAG.isSafeToSpeculativelyExecuteNode(N))
return SDValue();

unsigned SelOpcode = N1.getOpcode();
unsigned Opcode = N->getOpcode();
EVT VT = N->getValueType(0);
SDValue Cond = N1.getOperand(0);
SDValue TVal = N1.getOperand(1);
SDValue FVal = N1.getOperand(2);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

// This transform increases uses of N0, so freeze it to be safe.
Expand Down Expand Up @@ -13856,12 +13856,12 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
Opcode == ISD::ANY_EXTEND) &&
"Expected EXTEND dag node in input!");

if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
!N0.hasOneUse())
SDValue Cond, Op1, Op2;
if (!sd_match(N0, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(Op1),
m_Value(Op2))))) {
return SDValue();
}

SDValue Op1 = N0->getOperand(1);
SDValue Op2 = N0->getOperand(2);
if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
return SDValue();

Expand All @@ -13883,7 +13883,7 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,

SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
return DAG.getSelect(DL, VT, Cond, Ext1, Ext2);
}

/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
Expand Down Expand Up @@ -29617,13 +29617,14 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
}

// c ? X : Y -> c ? Log2(X) : Log2(Y)
if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
Op.hasOneUse()) {
if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
Depth + 1, AssumeNonZero))
if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
Depth + 1, AssumeNonZero))
return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
SDValue Cond, TVal, FVal;
if (sd_match(Op, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
m_Value(FVal))))) {
if (SDValue LogX =
takeInexpensiveLog2(DAG, DL, VT, TVal, Depth + 1, AssumeNonZero))
if (SDValue LogY =
takeInexpensiveLog2(DAG, DL, VT, FVal, Depth + 1, AssumeNonZero))
return DAG.getSelect(DL, VT, Cond, LogX, LogY);
}

// log2(umin(X, Y)) -> umin(log2(X), log2(Y))
Expand Down
25 changes: 25 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,31 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
}

TEST_F(SelectionDAGPatternMatchTest, matchSelectLike) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);

SDValue Cond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, Int32VT);
SDValue TVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue FVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);

SDValue VCond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, VInt32VT);
SDValue VTVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
SDValue VFVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);

SDValue Select = DAG->getNode(ISD::SELECT, DL, Int32VT, Cond, TVal, FVal);
SDValue VSelect =
DAG->getNode(ISD::VSELECT, DL, Int32VT, VCond, VTVal, VFVal);

using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(Select, m_SelectLike(m_Specific(Cond), m_Specific(TVal),
m_Specific(FVal))));
EXPECT_TRUE(
sd_match(VSelect, m_SelectLike(m_Specific(VCond), m_Specific(VTVal),
m_Specific(VFVal))));
}

namespace {
struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
using SDPatternMatch::BasicMatchContext::BasicMatchContext;
Expand Down