Skip to content

Commit ce0b098

Browse files
Add an additional check into the DAG combine
1 parent 5ca85da commit ce0b098

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12510,14 +12510,15 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1251012510
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1251112511
SDLoc DL(N);
1251212512

12513-
SDValue Op0 = N->getOperand(0);
12513+
SDValue Acc = N->getOperand(0);
1251412514
SDValue Op1 = N->getOperand(1);
12515+
SDValue Op2 = N->getOperand(2);
1251512516

1251612517
if (Op1->getOpcode() != ISD::MUL)
1251712518
return SDValue();
1251812519

1251912520
APInt ConstantOne;
12520-
if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
12521+
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
1252112522
!ConstantOne.isOne())
1252212523
return SDValue();
1252312524

@@ -12542,9 +12543,16 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1254212543
if (LHSIsSigned != RHSIsSigned)
1254312544
return SDValue();
1254412545

12546+
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12547+
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12548+
if (LHSIsSigned != NodeIsSigned &&
12549+
(Op1.getValueType().getVectorElementType() != AccElemVT ||
12550+
Op2.getValueType().getVectorElementType() != AccElemVT))
12551+
return SDValue();
12552+
1254512553
unsigned NewOpcode =
1254612554
LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12547-
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, LHSExtOp,
12555+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1254812556
RHSExtOp);
1254912557
}
1255012558

0 commit comments

Comments
 (0)