Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions llvm/lib/Target/AArch64/AArch64Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ let Predicates = [HasDotProd] in {
def ext_addv_to_udot_addv : GICombineRule<
(defs root:$root, ext_addv_to_udot_addv_matchinfo:$matchinfo),
(match (wip_match_opcode G_VECREDUCE_ADD):$root,
[{ return matchExtAddvToUdotAddv(*${root}, MRI, STI, ${matchinfo}); }]),
(apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }])
[{ return matchExtAddvToDotAddv(*${root}, MRI, STI, ${matchinfo}); }]),
(apply [{ applyExtAddvToDotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }])
>;
}

Expand All @@ -62,8 +62,10 @@ class push_opcode_through_ext<Instruction opcode, Instruction extOpcode> : GICom

def push_sub_through_zext : push_opcode_through_ext<G_SUB, G_ZEXT>;
def push_add_through_zext : push_opcode_through_ext<G_ADD, G_ZEXT>;
def push_mul_through_zext : push_opcode_through_ext<G_MUL, G_ZEXT>;
def push_sub_through_sext : push_opcode_through_ext<G_SUB, G_SEXT>;
def push_add_through_sext : push_opcode_through_ext<G_ADD, G_SEXT>;
def push_mul_through_sext : push_opcode_through_ext<G_MUL, G_SEXT>;

def AArch64PreLegalizerCombiner: GICombiner<
"AArch64PreLegalizerCombinerImpl", [all_combines,
Expand All @@ -75,8 +77,10 @@ def AArch64PreLegalizerCombiner: GICombiner<
ext_uaddv_to_uaddlv,
push_sub_through_zext,
push_add_through_zext,
push_mul_through_zext,
push_sub_through_sext,
push_add_through_sext]> {
push_add_through_sext,
push_mul_through_sext]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

Expand Down
90 changes: 59 additions & 31 deletions llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,13 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}

// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot(x, y))
// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot(x, 1))
// Similar to performVecReduceAddCombine in SelectionDAG
bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
const AArch64Subtarget &STI,
std::tuple<Register, Register, bool> &MatchInfo) {
bool matchExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
const AArch64Subtarget &STI,
std::tuple<Register, Register, bool> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
"Expected a G_VECREDUCE_ADD instruction");
assert(STI.hasDotProd() && "Target should have Dot Product feature");
Expand All @@ -246,31 +247,57 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
return false;

LLT SrcTy;
auto I1Opc = I1->getOpcode();
if (I1Opc == TargetOpcode::G_MUL) {
// Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT
// then the ext's must match the same opcode. It is set to the ext opcode on
// output.
auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
Register &Out2, unsigned &I1Opc) {
// If result of this has more than 1 use, then there is no point in creating
// udot instruction
if (!MRI.hasOneNonDBGUse(MidReg))
// a dot instruction
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
return false;

MachineInstr *ExtMI1 =
getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
MachineInstr *ExtMI2 =
getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
getDefIgnoringCopies(MI->getOperand(2).getReg(), MRI);
LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());

if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
return false;
if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
I1Opc != ExtMI1->getOpcode())
return false;
Out1 = ExtMI1->getOperand(1).getReg();
Out2 = ExtMI2->getOperand(1).getReg();
I1Opc = ExtMI1->getOpcode();
SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
return true;
};

LLT SrcTy;
unsigned I1Opc = I1->getOpcode();
if (I1Opc == TargetOpcode::G_MUL) {
Register Out1, Out2;
if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
return false;
SrcTy = MRI.getType(Out1);
std::get<0>(MatchInfo) = Out1;
std::get<1>(MatchInfo) = Out2;
} else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
SrcTy = MRI.getType(I1->getOperand(1).getReg());
std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
std::get<1>(MatchInfo) = 0;
Register I1Op = I1->getOperand(1).getReg();
MachineInstr *M = getDefIgnoringCopies(I1Op, MRI);
Register Out1, Out2;
if (M->getOpcode() == TargetOpcode::G_MUL &&
tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
SrcTy = MRI.getType(Out1);
std::get<0>(MatchInfo) = Out1;
std::get<1>(MatchInfo) = Out2;
} else {
SrcTy = MRI.getType(I1Op);
std::get<0>(MatchInfo) = I1Op;
std::get<1>(MatchInfo) = 0;
}
} else {
return false;
}
Expand All @@ -288,11 +315,11 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
return true;
}

void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &Builder,
GISelChangeObserver &Observer,
const AArch64Subtarget &STI,
std::tuple<Register, Register, bool> &MatchInfo) {
void applyExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &Builder,
GISelChangeObserver &Observer,
const AArch64Subtarget &STI,
std::tuple<Register, Register, bool> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
"Expected a G_VECREDUCE_ADD instruction");
assert(STI.hasDotProd() && "Target should have Dot Product feature");
Expand Down Expand Up @@ -553,15 +580,15 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}

// Pushes ADD/SUB through extend instructions to decrease the number of extend
// instruction at the end by allowing selection of {s|u}addl sooner

// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
// extend instructions at the end by allowing selection of {s|u}addl sooner
// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
Register DstReg, Register SrcReg1, Register SrcReg2) {
assert((MI.getOpcode() == TargetOpcode::G_ADD ||
MI.getOpcode() == TargetOpcode::G_SUB) &&
"Expected a G_ADD or G_SUB instruction\n");
MI.getOpcode() == TargetOpcode::G_SUB ||
MI.getOpcode() == TargetOpcode::G_MUL) &&
"Expected a G_ADD, G_SUB or G_MUL instruction\n");

// Deal with vector types only
LLT DstTy = MRI.getType(DstReg);
Expand Down Expand Up @@ -594,9 +621,10 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0);

// G_SUB has to sign-extend the result.
// G_ADD needs to sext from sext and can sext or zext from zext, so the
// original opcode is used.
if (MI.getOpcode() == TargetOpcode::G_ADD)
// G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
// needs to use the original opcode, so the original opcode is used for both.
if (MI.getOpcode() == TargetOpcode::G_ADD ||
MI.getOpcode() == TargetOpcode::G_MUL)
B.buildInstr(Opc, {DstReg}, {AddReg});
else
B.buildSExt(DstReg, AddReg);
Expand Down
110 changes: 41 additions & 69 deletions llvm/test/CodeGen/AArch64/aarch64-wide-mul.ll
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ define <16 x i32> @mul_i32(<16 x i8> %a, <16 x i8> %b) {
;
; CHECK-GI-LABEL: mul_i32:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v4.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v5.8h, v1.16b, #0
; CHECK-GI-NEXT: umull v0.4s, v2.4h, v3.4h
; CHECK-GI-NEXT: umull2 v1.4s, v2.8h, v3.8h
; CHECK-GI-NEXT: umull v2.4s, v4.4h, v5.4h
; CHECK-GI-NEXT: umull2 v3.4s, v4.8h, v5.8h
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v3.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ushll v0.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v1.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i32>
Expand Down Expand Up @@ -75,26 +73,20 @@ define <16 x i64> @mul_i64(<16 x i8> %a, <16 x i8> %b) {
;
; CHECK-GI-LABEL: mul_i64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v5.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
; CHECK-GI-NEXT: ushll v6.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-GI-NEXT: ushll v7.4s, v1.4h, #0
; CHECK-GI-NEXT: ushll2 v16.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll2 v17.4s, v1.8h, #0
; CHECK-GI-NEXT: umull v0.2d, v4.2s, v2.2s
; CHECK-GI-NEXT: umull2 v1.2d, v4.4s, v2.4s
; CHECK-GI-NEXT: umull v2.2d, v5.2s, v3.2s
; CHECK-GI-NEXT: umull2 v3.2d, v5.4s, v3.4s
; CHECK-GI-NEXT: umull v4.2d, v6.2s, v7.2s
; CHECK-GI-NEXT: umull2 v5.2d, v6.4s, v7.4s
; CHECK-GI-NEXT: umull v6.2d, v16.2s, v17.2s
; CHECK-GI-NEXT: umull2 v7.2d, v16.4s, v17.4s
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ushll v1.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v7.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll v0.2d, v1.2s, #0
; CHECK-GI-NEXT: ushll2 v1.2d, v1.4s, #0
; CHECK-GI-NEXT: ushll v2.2d, v3.2s, #0
; CHECK-GI-NEXT: ushll2 v3.2d, v3.4s, #0
; CHECK-GI-NEXT: ushll v4.2d, v5.2s, #0
; CHECK-GI-NEXT: ushll2 v5.2d, v5.4s, #0
; CHECK-GI-NEXT: ushll v6.2d, v7.2s, #0
; CHECK-GI-NEXT: ushll2 v7.2d, v7.4s, #0
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i64>
Expand Down Expand Up @@ -142,18 +134,12 @@ define <16 x i32> @mla_i32(<16 x i8> %a, <16 x i8> %b, <16 x i32> %c) {
;
; CHECK-GI-LABEL: mla_i32:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: umlal v2.4s, v6.4h, v7.4h
; CHECK-GI-NEXT: umlal2 v3.4s, v6.8h, v7.8h
; CHECK-GI-NEXT: umlal v4.4s, v0.4h, v1.4h
; CHECK-GI-NEXT: umlal2 v5.4s, v0.8h, v1.8h
; CHECK-GI-NEXT: mov v0.16b, v2.16b
; CHECK-GI-NEXT: mov v1.16b, v3.16b
; CHECK-GI-NEXT: mov v2.16b, v4.16b
; CHECK-GI-NEXT: mov v3.16b, v5.16b
; CHECK-GI-NEXT: umull v6.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v7.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: uaddw v0.4s, v2.4s, v6.4h
; CHECK-GI-NEXT: uaddw2 v1.4s, v3.4s, v6.8h
; CHECK-GI-NEXT: uaddw v2.4s, v4.4s, v7.4h
; CHECK-GI-NEXT: uaddw2 v3.4s, v5.4s, v7.8h
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i32>
Expand Down Expand Up @@ -186,35 +172,21 @@ define <16 x i64> @mla_i64(<16 x i8> %a, <16 x i8> %b, <16 x i64> %c) {
;
; CHECK-GI-LABEL: mla_i64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: mov v16.16b, v2.16b
; CHECK-GI-NEXT: mov v17.16b, v3.16b
; CHECK-GI-NEXT: mov v2.16b, v4.16b
; CHECK-GI-NEXT: mov v3.16b, v5.16b
; CHECK-GI-NEXT: mov v4.16b, v6.16b
; CHECK-GI-NEXT: mov v5.16b, v7.16b
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: ushll v18.4s, v6.4h, #0
; CHECK-GI-NEXT: ushll v20.4s, v7.4h, #0
; CHECK-GI-NEXT: ushll2 v19.4s, v6.8h, #0
; CHECK-GI-NEXT: ushll v21.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v22.4s, v7.8h, #0
; CHECK-GI-NEXT: ushll v23.4s, v1.4h, #0
; CHECK-GI-NEXT: ldp q6, q7, [sp]
; CHECK-GI-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-GI-NEXT: umlal v16.2d, v18.2s, v20.2s
; CHECK-GI-NEXT: umlal2 v17.2d, v18.4s, v20.4s
; CHECK-GI-NEXT: umlal v2.2d, v19.2s, v22.2s
; CHECK-GI-NEXT: umlal2 v3.2d, v19.4s, v22.4s
; CHECK-GI-NEXT: umlal v4.2d, v21.2s, v23.2s
; CHECK-GI-NEXT: umlal2 v5.2d, v21.4s, v23.4s
; CHECK-GI-NEXT: umlal v6.2d, v0.2s, v1.2s
; CHECK-GI-NEXT: umlal2 v7.2d, v0.4s, v1.4s
; CHECK-GI-NEXT: mov v0.16b, v16.16b
; CHECK-GI-NEXT: mov v1.16b, v17.16b
; CHECK-GI-NEXT: umull v16.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ldp q19, q20, [sp]
; CHECK-GI-NEXT: ushll v1.4s, v16.4h, #0
; CHECK-GI-NEXT: ushll2 v16.4s, v16.8h, #0
; CHECK-GI-NEXT: ushll v17.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v18.4s, v0.8h, #0
; CHECK-GI-NEXT: uaddw v0.2d, v2.2d, v1.2s
; CHECK-GI-NEXT: uaddw2 v1.2d, v3.2d, v1.4s
; CHECK-GI-NEXT: uaddw v2.2d, v4.2d, v16.2s
; CHECK-GI-NEXT: uaddw2 v3.2d, v5.2d, v16.4s
; CHECK-GI-NEXT: uaddw v4.2d, v6.2d, v17.2s
; CHECK-GI-NEXT: uaddw2 v5.2d, v7.2d, v17.4s
; CHECK-GI-NEXT: uaddw v6.2d, v19.2d, v18.2s
; CHECK-GI-NEXT: uaddw2 v7.2d, v20.2d, v18.4s
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i64>
Expand Down
Loading