22 changes: 15 additions & 7 deletions llvm/lib/Target/AArch64/AArch64Combine.td
@@ -217,11 +217,19 @@ def mul_const : GICombineRule<
(apply [{ applyAArch64MulConstCombine(*${root}, MRI, B, ${matchinfo}); }])
>;

def lower_mull : GICombineRule<
(defs root:$root),
(match (wip_match_opcode G_MUL):$root,
[{ return matchExtMulToMULL(*${root}, MRI); }]),
(apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer); }])
def mull_matchdata : GIDefMatchData<"std::tuple<bool, Register, Register>">;
def extmultomull : GICombineRule<
(defs root:$root, mull_matchdata:$matchinfo),
(match (G_MUL $dst, $src1, $src2):$root,
[{ return matchExtMulToMULL(*${root}, MRI, VT, ${matchinfo}); }]),
(apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer, ${matchinfo}); }])
>;

def lower_mulv2s64 : GICombineRule<
(defs root:$root, mull_matchdata:$matchinfo),
(match (G_MUL $dst, $src1, $src2):$root,
[{ return matchMulv2s64(*${root}, MRI); }]),
(apply [{ applyMulv2s64(*${root}, MRI, B, Observer); }])
>;

def build_vector_to_dup : GICombineRule<
@@ -316,7 +324,7 @@ def AArch64PostLegalizerLowering
icmp_lowering, build_vector_lowering,
lower_vector_fcmp, form_truncstore,
vector_sext_inreg_to_shift,
unmerge_ext_to_unmerge, lower_mull,
unmerge_ext_to_unmerge, lower_mulv2s64,
vector_unmerge_lowering, insertelt_nonconst]> {
}

@@ -339,5 +347,5 @@ def AArch64PostLegalizerCombiner
select_to_minmax, or_to_bsp, combine_concat_vector,
commute_constant_to_rhs,
push_freeze_to_prevent_poison_from_propagating,
combine_mul_cmlt, combine_use_vector_truncate]> {
combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> {
}
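
For illustration only (not part of the patch, register names and types made up for the example): the new extmultomull rule fires on a vector G_MUL whose operands are wide-enough extends and rewrites it to the AArch64-specific G_UMULL/G_SMULL, while lower_mulv2s64 only splits a v2s64 multiply. A rough sketch of the generic MIR for the zero-extend case before the combine:

    %a:_(<4 x s32>) = G_ZEXT %x(<4 x s16>)
    %b:_(<4 x s32>) = G_ZEXT %y(<4 x s16>)
    %m:_(<4 x s32>) = G_MUL %a, %b

which the apply step would replace with roughly:

    %m:_(<4 x s32>) = G_UMULL %x(<4 x s16>), %y(<4 x s16>)
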
116 changes: 116 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -438,6 +438,122 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}

// Match mul({z/s}ext, {z/s}ext) => {u/s}mull
bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
GISelValueTracking *KB,
std::tuple<bool, Register, Register> &MatchInfo) {
// Get the instructions that defined the source operand
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
unsigned I1Opc = I1->getOpcode();
unsigned I2Opc = I2->getOpcode();
unsigned EltSize = DstTy.getScalarSizeInBits();

if (!DstTy.isVector() || I1->getNumOperands() < 2 || I2->getNumOperands() < 2)
return false;

auto IsAtLeastDoubleExtend = [&](Register R) {
LLT Ty = MRI.getType(R);
return EltSize >= Ty.getScalarSizeInBits() * 2;
};

// If the source operands were EXTENDED before, then {U/S}MULL can be used
bool IsZExt1 =
I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
bool IsZExt2 =
I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
get<0>(MatchInfo) = true;
get<1>(MatchInfo) = I1->getOperand(1).getReg();
get<2>(MatchInfo) = I2->getOperand(1).getReg();
return true;
}

bool IsSExt1 =
I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
bool IsSExt2 =
I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
get<0>(MatchInfo) = false;
get<1>(MatchInfo) = I1->getOperand(1).getReg();
get<2>(MatchInfo) = I2->getOperand(1).getReg();
return true;
}

// Select UMULL if we can replace the other operand with an extend.
APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);
if (KB && (IsZExt1 || IsZExt2) &&
IsAtLeastDoubleExtend(IsZExt1 ? I1->getOperand(1).getReg()
: I2->getOperand(1).getReg())) {
Register ZExtOp =
IsZExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
if (KB->maskedValueIsZero(ZExtOp, Mask)) {
get<0>(MatchInfo) = true;
get<1>(MatchInfo) = IsZExt1 ? I1->getOperand(1).getReg() : ZExtOp;
get<2>(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand(1).getReg();
return true;
}
} else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
KB->maskedValueIsZero(MI.getOperand(1).getReg(), Mask) &&
KB->maskedValueIsZero(MI.getOperand(2).getReg(), Mask)) {
get<0>(MatchInfo) = true;
get<1>(MatchInfo) = MI.getOperand(1).getReg();
get<2>(MatchInfo) = MI.getOperand(2).getReg();
return true;
}

if (KB && (IsSExt1 || IsSExt2) &&
IsAtLeastDoubleExtend(IsSExt1 ? I1->getOperand(1).getReg()
: I2->getOperand(1).getReg())) {
Register SExtOp =
IsSExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
if (KB->computeNumSignBits(SExtOp) > EltSize / 2) {
get<0>(MatchInfo) = false;
get<1>(MatchInfo) = IsSExt1 ? I1->getOperand(1).getReg() : SExtOp;
get<2>(MatchInfo) = IsSExt1 ? SExtOp : I2->getOperand(1).getReg();
return true;
}
} else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
KB->computeNumSignBits(MI.getOperand(1).getReg()) > EltSize / 2 &&
KB->computeNumSignBits(MI.getOperand(2).getReg()) > EltSize / 2) {
get<0>(MatchInfo) = false;
get<1>(MatchInfo) = MI.getOperand(1).getReg();
get<2>(MatchInfo) = MI.getOperand(2).getReg();
return true;
}

return false;
}

void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, GISelChangeObserver &Observer,
std::tuple<bool, Register, Register> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_MUL &&
"Expected a G_MUL instruction");

// Get the instructions that defined the source operand
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
bool IsZExt = get<0>(MatchInfo);
Register Src1Reg = get<1>(MatchInfo);
Register Src2Reg = get<2>(MatchInfo);
LLT Src1Ty = MRI.getType(Src1Reg);
LLT Src2Ty = MRI.getType(Src2Reg);
LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;

if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);

B.buildInstr(IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
{MI.getOperand(0).getReg()}, {Src1Reg, Src2Reg});
MI.eraseFromParent();
}

class AArch64PostLegalizerCombinerImpl : public Combiner {
protected:
const CombinerHelper Helper;
62 changes: 9 additions & 53 deletions llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -1190,68 +1190,24 @@ void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
// Doing these two matches in one function to ensure that the order of matching
// will always be the same.
// Try lowering MUL to MULL before trying to scalarize if needed.
bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
bool matchMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI) {
// Get the instructions that defined the source operand
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);

if (DstTy.isVector()) {
// If the source operands were EXTENDED before, then {U/S}MULL can be used
unsigned I1Opc = I1->getOpcode();
unsigned I2Opc = I2->getOpcode();
if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
(I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
(MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
(MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
return true;
}
// If result type is v2s64, scalarise the instruction
else if (DstTy == LLT::fixed_vector(2, 64)) {
return true;
}
}
return false;
return DstTy == LLT::fixed_vector(2, 64);
}

void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, GISelChangeObserver &Observer) {
void applyMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, GISelChangeObserver &Observer) {
assert(MI.getOpcode() == TargetOpcode::G_MUL &&
"Expected a G_MUL instruction");

// Get the instructions that defined the source operand
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);

// If the source operands were EXTENDED before, then {U/S}MULL can be used
unsigned I1Opc = I1->getOpcode();
unsigned I2Opc = I2->getOpcode();
if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
(I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
(MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
(MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {

B.setInstrAndDebugLoc(MI);
B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
: AArch64::G_SMULL,
{MI.getOperand(0).getReg()},
{I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
MI.eraseFromParent();
}
// If result type is v2s64, scalarise the instruction
else if (DstTy == LLT::fixed_vector(2, 64)) {
LegalizerHelper Helper(*MI.getMF(), Observer, B);
B.setInstrAndDebugLoc(MI);
Helper.fewerElementsVector(
MI, 0,
DstTy.changeElementCount(
DstTy.getElementCount().divideCoefficientBy(2)));
}
assert(DstTy == LLT::fixed_vector(2, 64) && "Expected v2s64 Mul");
LegalizerHelper Helper(*MI.getMF(), Observer, B);
Helper.fewerElementsVector(
MI, 0,
DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2)));
}
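
For illustration only (not from the patch, names invented for the example): after this lowering, the LegalizerHelper narrows the <2 x s64> multiply, so a generic MIR multiply such as

    %d:_(<2 x s64>) = G_MUL %a, %b

ends up, roughly, as two scalar multiplies recombined into the original vector:

    %a0:_(s64), %a1:_(s64) = G_UNMERGE_VALUES %a(<2 x s64>)
    %b0:_(s64), %b1:_(s64) = G_UNMERGE_VALUES %b(<2 x s64>)
    %d0:_(s64) = G_MUL %a0, %b0
    %d1:_(s64) = G_MUL %a1, %b1
    %d:_(<2 x s64>) = G_BUILD_VECTOR %d0, %d1
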

class AArch64PostLegalizerLoweringImpl : public Combiner {