@@ -7197,6 +7197,27 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
71977197 }
71987198 }
71997199
7200+ // Handle fma/fmad special cases.
7201+ if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
7202+ assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
7203+ assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
7204+ Ops[2].getValueType() == VT && "FMA types must match!");
7205+ ConstantFPSDNode *C1 = dyn_cast<ConstantFPSDNode>(Ops[0]);
7206+ ConstantFPSDNode *C2 = dyn_cast<ConstantFPSDNode>(Ops[1]);
7207+ ConstantFPSDNode *C3 = dyn_cast<ConstantFPSDNode>(Ops[2]);
7208+ if (C1 && C2 && C3) {
7209+ APFloat V1 = C1->getValueAPF();
7210+ const APFloat &V2 = C2->getValueAPF();
7211+ const APFloat &V3 = C3->getValueAPF();
7212+ if (Opcode == ISD::FMAD) {
7213+ V1.multiply(V2, APFloat::rmNearestTiesToEven);
7214+ V1.add(V3, APFloat::rmNearestTiesToEven);
7215+ } else
7216+ V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
7217+ return getConstantFP(V1, DL, VT);
7218+ }
7219+ }
7220+
72007221 // This is for vector folding only from here on.
72017222 if (!VT.isVector())
72027223 return SDValue();
@@ -8159,33 +8180,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
81598180 "Operand is DELETED_NODE!");
81608181 // Perform various simplifications.
81618182 switch (Opcode) {
8162- case ISD::FMA:
8163- case ISD::FMAD: {
8164- assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
8165- assert(N1.getValueType() == VT && N2.getValueType() == VT &&
8166- N3.getValueType() == VT && "FMA types must match!");
8167- ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
8168- ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
8169- ConstantFPSDNode *N3CFP = dyn_cast<ConstantFPSDNode>(N3);
8170- if (N1CFP && N2CFP && N3CFP) {
8171- APFloat V1 = N1CFP->getValueAPF();
8172- const APFloat &V2 = N2CFP->getValueAPF();
8173- const APFloat &V3 = N3CFP->getValueAPF();
8174- if (Opcode == ISD::FMAD) {
8175- V1.multiply(V2, APFloat::rmNearestTiesToEven);
8176- V1.add(V3, APFloat::rmNearestTiesToEven);
8177- } else
8178- V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
8179- return getConstantFP(V1, DL, VT);
8180- }
8181- break;
8182- }
8183- case ISD::FSHL:
8184- case ISD::FSHR:
8185- // Constant folding.
8186- if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, {N1, N2, N3}))
8187- return V;
8188- break;
81898183 case ISD::BUILD_VECTOR: {
81908184 // Attempt to simplify BUILD_VECTOR.
81918185 SDValue Ops[] = {N1, N2, N3};
@@ -8211,12 +8205,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
82118205 // Use FoldSetCC to simplify SETCC's.
82128206 if (SDValue V = FoldSetCC(VT, N1, N2, cast<CondCodeSDNode>(N3)->get(), DL))
82138207 return V;
8214- // Vector constant folding.
8215- SDValue Ops[] = {N1, N2, N3};
8216- if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, Ops)) {
8217- NewSDValueDbgMsg(V, "New node vector constant folding: ", this);
8218- return V;
8219- }
82208208 break;
82218209 }
82228210 case ISD::SELECT:
@@ -8352,6 +8340,20 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
83528340 }
83538341 }
83548342
8343+ // Perform trivial constant folding for arithmetic operators.
8344+ switch (Opcode) {
8345+ case ISD::FMA:
8346+ case ISD::FMAD:
8347+ case ISD::SETCC:
8348+ case ISD::BITCAST:
8349+ case ISD::FSHL:
8350+ case ISD::FSHR:
8351+ if (SDValue SV =
8352+ FoldConstantArithmetic(Opcode, DL, VT, {N1, N2, N3}, Flags))
8353+ return SV;
8354+ break;
8355+ }
8356+
83558357 // Memoize node if it doesn't produce a glue result.
83568358 SDNode *N;
83578359 SDVTList VTs = getVTList(VT);
0 commit comments