Skip to content

Commit bd7d333

Browse files
Make DAG combine one function again
This is so the MUL fold does not happen unless the extend fold can be performed. As otherwise a lot of code would need to be repeated to check that it can happen.
1 parent 406041f commit bd7d333

File tree

1 file changed

+16
-30
lines changed

1 file changed

+16
-30
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12504,25 +12504,17 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1250412504
}
1250512505

1250612506
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12507-
// Only perform the DAG combine if there is custom lowering provided by the
12508-
// target.
12509-
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
12510-
N->getOperand(1).getValueType()))
12511-
return SDValue();
12512-
12513-
if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
12514-
return Res;
12515-
if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
12516-
return Res;
12517-
return SDValue();
12518-
}
12519-
12520-
SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
12521-
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
12522-
// PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
12507+
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
12508+
// Splat(1)) into
12509+
// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
12510+
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
12511+
// Splat(1)) into
12512+
// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
1252312513
SDLoc DL(N);
1252412514

12515+
SDValue Op0 = N->getOperand(0);
1252512516
SDValue Op1 = N->getOperand(1);
12517+
1252612518
if (Op1->getOpcode() != ISD::MUL)
1252712519
return SDValue();
1252812520

@@ -12531,18 +12523,8 @@ SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
1253112523
!ConstantOne.isOne())
1253212524
return SDValue();
1253312525

12534-
return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
12535-
Op1->getOperand(0), Op1->getOperand(1));
12536-
}
12537-
12538-
SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
12539-
// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
12540-
// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
12541-
// PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
12542-
// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
12543-
SDLoc DL(N);
12544-
SDValue ExtMulOpLHS = N->getOperand(1);
12545-
SDValue ExtMulOpRHS = N->getOperand(2);
12526+
SDValue ExtMulOpLHS = Op1->getOperand(0);
12527+
SDValue ExtMulOpRHS = Op1->getOperand(1);
1254612528
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
1254712529
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
1254812530
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
@@ -12554,6 +12536,10 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
1255412536
EVT MulOpLHSVT = MulOpLHS.getValueType();
1255512537
if (MulOpLHSVT != MulOpRHS.getValueType())
1255612538
return SDValue();
12539+
// Only perform the DAG combine if there is custom lowering provided by the
12540+
// target
12541+
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), MulOpLHSVT))
12542+
return SDValue();
1255712543

1255812544
bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
1255912545
bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -12562,8 +12548,8 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
1256212548

1256312549
unsigned NewOpcode =
1256412550
LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12565-
return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
12566-
MulOpLHS, MulOpRHS);
12551+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, MulOpLHS,
12552+
MulOpRHS);
1256712553
}
1256812554

1256912555
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

0 commit comments

Comments
 (0)