Skip to content

Commit 7814fce

Browse files
Simplify DAG combine
1 parent e694bcf commit 7814fce

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12514,11 +12514,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1251412514
SDValue Op1 = N->getOperand(1);
1251512515
SDValue Op2 = N->getOperand(2);
1251612516

12517-
if (Op1->getOpcode() != ISD::MUL)
12518-
return SDValue();
12519-
1252012517
APInt ConstantOne;
12521-
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12518+
if (Op1->getOpcode() != ISD::MUL ||
12519+
!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
1252212520
!ConstantOne.isOne())
1252312521
return SDValue();
1252412522

@@ -12529,29 +12527,28 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1252912527
if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
1253012528
return SDValue();
1253112529

12530+
// For a 2-stage extend the signedness of both of the extends must be the
12531+
// same. This is so the node can be folded into only a signed or unsigned
12532+
// node.
1253212533
SDValue LHSExtOp = LHS->getOperand(0);
1253312534
SDValue RHSExtOp = RHS->getOperand(0);
1253412535
EVT LHSExtOpVT = LHSExtOp.getValueType();
12535-
if (LHSExtOpVT != RHSExtOp.getValueType())
12536+
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
1253612537
return SDValue();
1253712538

1253812539
// FIXME: Add a check to only perform the DAG combine if there is lowering
1253912540
// provided by the target
1254012541

12541-
bool LHSIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12542-
bool RHSIsSigned = RHSOpcode == ISD::SIGN_EXTEND;
12543-
if (LHSIsSigned != RHSIsSigned)
12544-
return SDValue();
12542+
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
1254512543

1254612544
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1254712545
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12548-
if (LHSIsSigned != NodeIsSigned &&
12549-
(Op1.getValueType().getVectorElementType() != AccElemVT ||
12550-
Op2.getValueType().getVectorElementType() != AccElemVT))
12546+
if (ExtIsSigned != NodeIsSigned &&
12547+
Op1.getValueType().getVectorElementType() != AccElemVT)
1255112548
return SDValue();
1255212549

1255312550
unsigned NewOpcode =
12554-
LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12551+
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
1255512552
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1255612553
RHSExtOp);
1255712554
}

0 commit comments

Comments
 (0)