Skip to content

Commit 0f0c161

Browse files
committed
[SelectionDAG] Share code for two of our multiply expansions. NFC
ExpandIntRes_MUL and forceExpandWideMul have very similar code. ExpandIntRes_MUL calculates Lo and Hi half result from the 2 sources with Hi and Lo halves. forceExpandWideMul calculates the Lo and Hi half of the full product of 2 values. The only differences are that forceExpandWideMul uses ISD::SRA instead of ISD::SRL for a signed wide multiply. ExpandIntRes_MUL needs 2 additionals multiplies and 2 adds to multiply HiRHS*LHS and HiLHS*RHS and add them to Hi. This patch introduces a new function that takes HiLHS and HiRHS as optional values. If they are not null, they will be used in the calculation of the Hi half. The Signed flag can only be set when HiLHS/HiRHS are null.
1 parent 3a608ef commit 0f0c161

File tree

3 files changed

+73
-77
lines changed

3 files changed

+73
-77
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5499,6 +5499,15 @@ class TargetLowering : public TargetLoweringBase {
54995499
bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
55005500
SelectionDAG &DAG) const;
55015501

5502+
/// Calculate the product twice the width of LHS and RHS. If HiLHS/HiRHS are
5503+
/// non-null they will be included in the multiplication. The expansion works
5504+
/// by splitting the 2 inputs into 4 pieces that we can multiply and add
5505+
/// together without neding MULH or MUL_LOHI.
5506+
void forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
5507+
SDValue &Lo, SDValue &Hi, SDValue LHS, SDValue RHS,
5508+
SDValue HiLHS = SDValue(),
5509+
SDValue HiRHS = SDValue()) const;
5510+
55025511
/// Calculate full product of LHS and RHS either via a libcall or through
55035512
/// brute force expansion of the multiplication. The expansion works by
55045513
/// splitting the 2 inputs into 4 pieces that we can multiply and add together

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4294,44 +4294,7 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
42944294
LC = RTLIB::MUL_I128;
42954295

42964296
if (LC == RTLIB::UNKNOWN_LIBCALL || !TLI.getLibcallName(LC)) {
4297-
// We'll expand the multiplication by brute force because we have no other
4298-
// options. This is a trivially-generalized version of the code from
4299-
// Hacker's Delight (itself derived from Knuth's Algorithm M from section
4300-
// 4.3.1).
4301-
unsigned Bits = NVT.getSizeInBits();
4302-
unsigned HalfBits = Bits >> 1;
4303-
SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl,
4304-
NVT);
4305-
SDValue LLL = DAG.getNode(ISD::AND, dl, NVT, LL, Mask);
4306-
SDValue RLL = DAG.getNode(ISD::AND, dl, NVT, RL, Mask);
4307-
4308-
SDValue T = DAG.getNode(ISD::MUL, dl, NVT, LLL, RLL);
4309-
SDValue TL = DAG.getNode(ISD::AND, dl, NVT, T, Mask);
4310-
4311-
SDValue Shift = DAG.getShiftAmountConstant(HalfBits, NVT, dl);
4312-
SDValue TH = DAG.getNode(ISD::SRL, dl, NVT, T, Shift);
4313-
SDValue LLH = DAG.getNode(ISD::SRL, dl, NVT, LL, Shift);
4314-
SDValue RLH = DAG.getNode(ISD::SRL, dl, NVT, RL, Shift);
4315-
4316-
SDValue U = DAG.getNode(ISD::ADD, dl, NVT,
4317-
DAG.getNode(ISD::MUL, dl, NVT, LLH, RLL), TH);
4318-
SDValue UL = DAG.getNode(ISD::AND, dl, NVT, U, Mask);
4319-
SDValue UH = DAG.getNode(ISD::SRL, dl, NVT, U, Shift);
4320-
4321-
SDValue V = DAG.getNode(ISD::ADD, dl, NVT,
4322-
DAG.getNode(ISD::MUL, dl, NVT, LLL, RLH), UL);
4323-
SDValue VH = DAG.getNode(ISD::SRL, dl, NVT, V, Shift);
4324-
4325-
SDValue W = DAG.getNode(ISD::ADD, dl, NVT,
4326-
DAG.getNode(ISD::MUL, dl, NVT, LLH, RLH),
4327-
DAG.getNode(ISD::ADD, dl, NVT, UH, VH));
4328-
Lo = DAG.getNode(ISD::ADD, dl, NVT, TL,
4329-
DAG.getNode(ISD::SHL, dl, NVT, V, Shift));
4330-
4331-
Hi = DAG.getNode(ISD::ADD, dl, NVT, W,
4332-
DAG.getNode(ISD::ADD, dl, NVT,
4333-
DAG.getNode(ISD::MUL, dl, NVT, RH, LL),
4334-
DAG.getNode(ISD::MUL, dl, NVT, RL, LH)));
4297+
TLI.forceExpandMUL(DAG, dl, /*Signed=*/false, Lo, Hi, LL, RL, LH, RH);
43354298
return;
43364299
}
43374300

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10857,6 +10857,64 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
1085710857
return DAG.getSelect(dl, VT, Cond, SatVal, Result);
1085810858
}
1085910859

10860+
void TargetLowering::forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl,
10861+
bool Signed, SDValue &Lo, SDValue &Hi,
10862+
SDValue LHS, SDValue RHS, SDValue HiLHS,
10863+
SDValue HiRHS) const {
10864+
EVT VT = LHS.getValueType();
10865+
assert(RHS.getValueType() == VT && "Mismatching operand types");
10866+
10867+
assert((HiLHS && HiRHS) || (!HiLHS && !HiRHS));
10868+
assert((!Signed || !HiLHS) &&
10869+
"Signed flag should only be set when HiLHS and RiRHS are null");
10870+
10871+
// We'll expand the multiplication by brute force because we have no other
10872+
// options. This is a trivially-generalized version of the code from
10873+
// Hacker's Delight (itself derived from Knuth's Algorithm M from section
10874+
// 4.3.1). If Signed is set, we can use arithmetic right shifts to propagate
10875+
// sign bits while calculating the Hi half.
10876+
unsigned Bits = VT.getSizeInBits();
10877+
unsigned HalfBits = Bits / 2;
10878+
SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
10879+
SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
10880+
SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
10881+
10882+
SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
10883+
SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
10884+
10885+
SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
10886+
// This is always an unsigned shift.
10887+
SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
10888+
10889+
unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
10890+
SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
10891+
SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
10892+
10893+
SDValue U =
10894+
DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
10895+
SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
10896+
SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
10897+
10898+
SDValue V =
10899+
DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
10900+
SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
10901+
10902+
Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
10903+
DAG.getNode(ISD::SHL, dl, VT, V, Shift));
10904+
10905+
Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
10906+
DAG.getNode(ISD::ADD, dl, VT, UH, VH));
10907+
10908+
// If HiLHS and HiRHS are set, multiply them by the opposite low part and add
10909+
// them to products to Hi.
10910+
if (HiLHS) {
10911+
Hi = DAG.getNode(ISD::ADD, dl, VT, Hi,
10912+
DAG.getNode(ISD::ADD, dl, VT,
10913+
DAG.getNode(ISD::MUL, dl, VT, HiRHS, LHS),
10914+
DAG.getNode(ISD::MUL, dl, VT, RHS, HiLHS)));
10915+
}
10916+
}
10917+
1086010918
void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
1086110919
bool Signed, const SDValue LHS,
1086210920
const SDValue RHS, SDValue &Lo,
@@ -10876,7 +10934,11 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
1087610934
else if (WideVT == MVT::i128)
1087710935
LC = RTLIB::MUL_I128;
1087810936

10879-
if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
10937+
if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
10938+
forceExpandMUL(DAG, dl, Signed, Lo, Hi, LHS, RHS);
10939+
return;
10940+
}
10941+
1088010942
SDValue HiLHS, HiRHS;
1088110943
if (Signed) {
1088210944
// The high part is obtained by SRA'ing all but one of the bits of low
@@ -10916,44 +10978,6 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
1091610978
Lo = Ret.getOperand(1);
1091710979
Hi = Ret.getOperand(0);
1091810980
}
10919-
return;
10920-
}
10921-
10922-
// Expand the multiplication by brute force. This is a generalized-version of
10923-
// the code from Hacker's Delight (itself derived from Knuth's Algorithm M
10924-
// from section 4.3.1) combined with the Hacker's delight code
10925-
// for calculating mulhs.
10926-
unsigned Bits = VT.getSizeInBits();
10927-
unsigned HalfBits = Bits / 2;
10928-
SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
10929-
SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
10930-
SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
10931-
10932-
SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
10933-
SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
10934-
10935-
SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
10936-
// This is always an unsigned shift.
10937-
SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
10938-
10939-
unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
10940-
SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
10941-
SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
10942-
10943-
SDValue U =
10944-
DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
10945-
SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
10946-
SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
10947-
10948-
SDValue V =
10949-
DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
10950-
SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
10951-
10952-
Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
10953-
DAG.getNode(ISD::SHL, dl, VT, V, Shift));
10954-
10955-
Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
10956-
DAG.getNode(ISD::ADD, dl, VT, UH, VH));
1095710981
}
1095810982

1095910983
SDValue

0 commit comments

Comments
 (0)