Skip to content

Commit 4938d16

Browse files
committed
[AArch64][GlobalISel] Improve MULL generation
This splits the existing post-legalize lowering of vector umull/smull into two parts - one to perform the optimization of mul(ext,ext) -> mull and one to perform the v2i64 mul scalarization. The mull part is moved to post legalizer combine and has been taught a few extra tricks from SDAG, using known bits to convert mul(sext, zext) or mul(zext, zero-upper-bits) into umull. This can be important to prevent v2i64 scalarization of muls.
1 parent 708b154 commit 4938d16

File tree

5 files changed

+201
-295
lines changed

5 files changed

+201
-295
lines changed

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,19 @@ def mul_const : GICombineRule<
208208
(apply [{ applyAArch64MulConstCombine(*${root}, MRI, B, ${matchinfo}); }])
209209
>;
210210

211-
def lower_mull : GICombineRule<
212-
(defs root:$root),
211+
// Shared matchdata for the MULL combines:
// {IsUnsigned, LHS source reg, RHS source reg}.
def mull_matchdata : GIDefMatchData<"std::tuple<bool, Register, Register>">;

// Combine mul({z/s}ext x, {z/s}ext y) into G_UMULL/G_SMULL. Runs in the
// post-legalizer combiner so it fires before the v2s64 mul scalarization.
def extmultomull : GICombineRule<
  (defs root:$root, mull_matchdata:$matchinfo),
  (match (wip_match_opcode G_MUL):$root,
      [{ return matchExtMulToMULL(*${root}, MRI, KB, ${matchinfo}); }]),
  (apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer, ${matchinfo}); }])
>;

// Scalarize a v2s64 G_MUL, which has no native AArch64 vector instruction.
def lower_mulv2s64 : GICombineRule<
  (defs root:$root, mull_matchdata:$matchinfo),
  (match (wip_match_opcode G_MUL):$root,
      [{ return matchMulv2s64(*${root}, MRI); }]),
  (apply [{ applyMulv2s64(*${root}, MRI, B, Observer); }])
>;
217225

218226
def build_vector_to_dup : GICombineRule<
@@ -307,7 +315,7 @@ def AArch64PostLegalizerLowering
307315
icmp_lowering, build_vector_lowering,
308316
lower_vector_fcmp, form_truncstore,
309317
vector_sext_inreg_to_shift,
310-
unmerge_ext_to_unmerge, lower_mull,
318+
unmerge_ext_to_unmerge, lower_mulv2s64,
311319
vector_unmerge_lowering, insertelt_nonconst]> {
312320
}
313321

@@ -330,5 +338,5 @@ def AArch64PostLegalizerCombiner
330338
select_to_minmax, or_to_bsp, combine_concat_vector,
331339
commute_constant_to_rhs,
332340
push_freeze_to_prevent_poison_from_propagating,
333-
combine_mul_cmlt, combine_use_vector_truncate]> {
341+
combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> {
334342
}

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,109 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
438438
MI.eraseFromParent();
439439
}
440440

441+
// Match mul({z/s}ext , {z/s}ext) => {u/s}mull
442+
bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
443+
GISelKnownBits *KB,
444+
std::tuple<bool, Register, Register> &MatchInfo) {
445+
// Get the instructions that defined the source operand
446+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
447+
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
448+
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
449+
unsigned I1Opc = I1->getOpcode();
450+
unsigned I2Opc = I2->getOpcode();
451+
452+
if (!DstTy.isVector() || I1->getNumOperands() < 2 || I2->getNumOperands() < 2)
453+
return false;
454+
455+
auto IsAtLeastDoubleExtend = [&](Register R) {
456+
LLT Ty = MRI.getType(R);
457+
return DstTy.getScalarSizeInBits() >= Ty.getScalarSizeInBits() * 2;
458+
};
459+
460+
// If the source operands were EXTENDED before, then {U/S}MULL can be used
461+
bool IsZExt1 =
462+
I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
463+
bool IsZExt2 =
464+
I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
465+
if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
466+
IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
467+
get<0>(MatchInfo) = true;
468+
get<1>(MatchInfo) = I1->getOperand(1).getReg();
469+
get<2>(MatchInfo) = I2->getOperand(1).getReg();
470+
return true;
471+
}
472+
473+
bool IsSExt1 =
474+
I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
475+
bool IsSExt2 =
476+
I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
477+
if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
478+
IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
479+
get<0>(MatchInfo) = false;
480+
get<1>(MatchInfo) = I1->getOperand(1).getReg();
481+
get<2>(MatchInfo) = I2->getOperand(1).getReg();
482+
return true;
483+
}
484+
485+
// Select SMULL if we can replace zext with sext.
486+
if (KB && ((IsSExt1 && IsZExt2) || (IsZExt1 && IsSExt2)) &&
487+
IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
488+
IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
489+
Register ZExtOp =
490+
IsZExt1 ? I1->getOperand(1).getReg() : I2->getOperand(1).getReg();
491+
if (KB->signBitIsZero(ZExtOp)) {
492+
get<0>(MatchInfo) = false;
493+
get<1>(MatchInfo) = I1->getOperand(1).getReg();
494+
get<2>(MatchInfo) = I2->getOperand(1).getReg();
495+
return true;
496+
}
497+
}
498+
499+
// Select UMULL if we can replace the other operand with an extend.
500+
if (KB && (IsZExt1 || IsZExt2) &&
501+
IsAtLeastDoubleExtend(IsZExt1 ? I1->getOperand(1).getReg()
502+
: I2->getOperand(1).getReg())) {
503+
APInt Mask = APInt::getHighBitsSet(DstTy.getScalarSizeInBits(),
504+
DstTy.getScalarSizeInBits() / 2);
505+
Register ZExtOp =
506+
IsZExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
507+
if (KB->maskedValueIsZero(ZExtOp, Mask)) {
508+
get<0>(MatchInfo) = true;
509+
get<1>(MatchInfo) = IsZExt1 ? I1->getOperand(1).getReg() : ZExtOp;
510+
get<2>(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand(1).getReg();
511+
return true;
512+
}
513+
}
514+
return false;
515+
}
516+
517+
// Rewrite the matched G_MUL into G_UMULL/G_SMULL using the sources
// recorded in MatchInfo ({IsUnsigned, LHS source, RHS source}).
void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
                       MachineIRBuilder &B, GISelChangeObserver &Observer,
                       std::tuple<bool, Register, Register> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL &&
         "Expected a G_MUL instruction");

  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  bool IsZExt = std::get<0>(MatchInfo);
  Register Src1Reg = std::get<1>(MatchInfo);
  Register Src2Reg = std::get<2>(MatchInfo);
  LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
  unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;

  // MULL takes sources of exactly half the destination element width; any
  // source that is narrower (or wider, for the known-bits case) is adjusted
  // with an extend-or-truncate of the matching signedness first.
  if (MRI.getType(Src1Reg).getScalarSizeInBits() * 2 !=
      DstTy.getScalarSizeInBits())
    Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
  if (MRI.getType(Src2Reg).getScalarSizeInBits() * 2 !=
      DstTy.getScalarSizeInBits())
    Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);

  B.setInstrAndDebugLoc(MI);
  B.buildInstr(IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
               {MI.getOperand(0).getReg()}, {Src1Reg, Src2Reg});
  MI.eraseFromParent();
}
543+
441544
class AArch64PostLegalizerCombinerImpl : public Combiner {
442545
protected:
443546
// TODO: Make CombinerHelper methods const.

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 10 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,68 +1177,25 @@ void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
11771177
// Doing these two matches in one function to ensure that the order of matching
11781178
// will always be the same.
11791179
// Try lowering MUL to MULL before trying to scalarize if needed.
1180-
// Match any G_MUL producing v2s64: AArch64 has no v2i64 vector multiply,
// so anything still here must be scalarized. (The mul(ext,ext)->MULL
// combine now runs earlier, in the post-legalizer combiner, so profitable
// MULLs have already been formed by this point.)
bool matchMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI) {
  return MRI.getType(MI.getOperand(0).getReg()) == LLT::fixed_vector(2, 64);
}
12051185

1206-
void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
1207-
MachineIRBuilder &B, GISelChangeObserver &Observer) {
1186+
void applyMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI,
1187+
MachineIRBuilder &B, GISelChangeObserver &Observer) {
12081188
assert(MI.getOpcode() == TargetOpcode::G_MUL &&
12091189
"Expected a G_MUL instruction");
12101190

12111191
// Get the instructions that defined the source operand
12121192
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1213-
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1214-
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1215-
1216-
// If the source operands were EXTENDED before, then {U/S}MULL can be used
1217-
unsigned I1Opc = I1->getOpcode();
1218-
unsigned I2Opc = I2->getOpcode();
1219-
if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1220-
(I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1221-
(MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1222-
MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1223-
(MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1224-
MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1225-
1226-
B.setInstrAndDebugLoc(MI);
1227-
B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
1228-
: AArch64::G_SMULL,
1229-
{MI.getOperand(0).getReg()},
1230-
{I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
1231-
MI.eraseFromParent();
1232-
}
1233-
// If result type is v2s64, scalarise the instruction
1234-
else if (DstTy == LLT::fixed_vector(2, 64)) {
1235-
LegalizerHelper Helper(*MI.getMF(), Observer, B);
1236-
B.setInstrAndDebugLoc(MI);
1237-
Helper.fewerElementsVector(
1238-
MI, 0,
1239-
DstTy.changeElementCount(
1240-
DstTy.getElementCount().divideCoefficientBy(2)));
1241-
}
1193+
assert(DstTy == LLT::fixed_vector(2, 64) && "Expected v2s64 Mul");
1194+
LegalizerHelper Helper(*MI.getMF(), Observer, B);
1195+
B.setInstrAndDebugLoc(MI);
1196+
Helper.fewerElementsVector(
1197+
MI, 0,
1198+
DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2)));
12421199
}
12431200

12441201
class AArch64PostLegalizerLoweringImpl : public Combiner {

0 commit comments

Comments
 (0)