Skip to content

Commit e83eee3

Browse files
authored
[DAG] Create SDPatternMatch method m_SelectLike to match ISD::Select and ISD::VSelect (#164069)
Fixes #150019
1 parent c9fb37c commit e83eee3

File tree

3 files changed

+49
-20
lines changed

3 files changed

+49
-20
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,11 @@ m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
558558
return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
559559
}
560560

561+
template <typename T0_P, typename T1_P, typename T2_P>
562+
inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
563+
return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
564+
}
565+
561566
template <typename T0_P, typename T1_P, typename T2_P>
562567
inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
563568
m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,28 +2476,27 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
24762476
/// masked vector operation if the target supports it.
24772477
static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24782478
bool ShouldCommuteOperands) {
2479-
// Match a select as operand 1. The identity constant that we are looking for
2480-
// is only valid as operand 1 of a non-commutative binop.
24812479
SDValue N0 = N->getOperand(0);
24822480
SDValue N1 = N->getOperand(1);
2481+
2482+
// Match a select as operand 1. The identity constant that we are looking for
2483+
// is only valid as operand 1 of a non-commutative binop.
24832484
if (ShouldCommuteOperands)
24842485
std::swap(N0, N1);
24852486

2486-
unsigned SelOpcode = N1.getOpcode();
2487-
if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2488-
!N1.hasOneUse())
2487+
SDValue Cond, TVal, FVal;
2488+
if (!sd_match(N1, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
2489+
m_Value(FVal)))))
24892490
return SDValue();
24902491

24912492
// We can't hoist all instructions because of immediate UB (not speculatable).
24922493
// For example div/rem by zero.
24932494
if (!DAG.isSafeToSpeculativelyExecuteNode(N))
24942495
return SDValue();
24952496

2497+
unsigned SelOpcode = N1.getOpcode();
24962498
unsigned Opcode = N->getOpcode();
24972499
EVT VT = N->getValueType(0);
2498-
SDValue Cond = N1.getOperand(0);
2499-
SDValue TVal = N1.getOperand(1);
2500-
SDValue FVal = N1.getOperand(2);
25012500
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25022501

25032502
// This transform increases uses of N0, so freeze it to be safe.
@@ -13856,12 +13855,11 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
1385613855
Opcode == ISD::ANY_EXTEND) &&
1385713856
"Expected EXTEND dag node in input!");
1385813857

13859-
if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
13860-
!N0.hasOneUse())
13858+
SDValue Cond, Op1, Op2;
13859+
if (!sd_match(N0, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(Op1),
13860+
m_Value(Op2)))))
1386113861
return SDValue();
1386213862

13863-
SDValue Op1 = N0->getOperand(1);
13864-
SDValue Op2 = N0->getOperand(2);
1386513863
if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
1386613864
return SDValue();
1386713865

@@ -13883,7 +13881,7 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
1388313881

1388413882
SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
1388513883
SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
13886-
return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
13884+
return DAG.getSelect(DL, VT, Cond, Ext1, Ext2);
1388713885
}
1388813886

1388913887
/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
@@ -29620,13 +29618,14 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
2962029618
}
2962129619

2962229620
// c ? X : Y -> c ? Log2(X) : Log2(Y)
29623-
if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
29624-
Op.hasOneUse()) {
29625-
if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
29626-
Depth + 1, AssumeNonZero))
29627-
if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
29628-
Depth + 1, AssumeNonZero))
29629-
return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
29621+
SDValue Cond, TVal, FVal;
29622+
if (sd_match(Op, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
29623+
m_Value(FVal))))) {
29624+
if (SDValue LogX =
29625+
takeInexpensiveLog2(DAG, DL, VT, TVal, Depth + 1, AssumeNonZero))
29626+
if (SDValue LogY =
29627+
takeInexpensiveLog2(DAG, DL, VT, FVal, Depth + 1, AssumeNonZero))
29628+
return DAG.getSelect(DL, VT, Cond, LogX, LogY);
2963029629
}
2963129630

2963229631
// log2(umin(X, Y)) -> umin(log2(X), log2(Y))

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,31 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
550550
EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
551551
}
552552

553+
TEST_F(SelectionDAGPatternMatchTest, matchSelectLike) {
554+
SDLoc DL;
555+
auto Int32VT = EVT::getIntegerVT(Context, 32);
556+
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
557+
558+
SDValue Cond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, Int32VT);
559+
SDValue TVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
560+
SDValue FVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
561+
562+
SDValue VCond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, VInt32VT);
563+
SDValue VTVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
564+
SDValue VFVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
565+
566+
SDValue Select = DAG->getNode(ISD::SELECT, DL, Int32VT, Cond, TVal, FVal);
567+
SDValue VSelect =
568+
DAG->getNode(ISD::VSELECT, DL, Int32VT, VCond, VTVal, VFVal);
569+
570+
using namespace SDPatternMatch;
571+
EXPECT_TRUE(sd_match(Select, m_SelectLike(m_Specific(Cond), m_Specific(TVal),
572+
m_Specific(FVal))));
573+
EXPECT_TRUE(
574+
sd_match(VSelect, m_SelectLike(m_Specific(VCond), m_Specific(VTVal),
575+
m_Specific(VFVal))));
576+
}
577+
553578
namespace {
554579
struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
555580
using SDPatternMatch::BasicMatchContext::BasicMatchContext;

0 commit comments

Comments
 (0)