@@ -228,13 +228,13 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
228228 B.buildConstant (LLT::scalar (64 ), -static_cast <int64_t >(MinOffset)));
229229}
230230
231- // Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot (x, y))
232- // Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add(udot (x, y))
233- // Or vecreduce_add(ext(x)) -> vecreduce_add(udot (x, 1))
231+ // Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot (x, y))
232+ // Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot (x, y))
233+ // Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot (x, 1))
234234// Similar to performVecReduceAddCombine in SelectionDAG
235- bool matchExtAddvToUdotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
236- const AArch64Subtarget &STI,
237- std::tuple<Register, Register, bool > &MatchInfo) {
235+ bool matchExtAddvToDotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
236+ const AArch64Subtarget &STI,
237+ std::tuple<Register, Register, bool > &MatchInfo) {
238238 assert (MI.getOpcode () == TargetOpcode::G_VECREDUCE_ADD &&
239239 " Expected a G_VECREDUCE_ADD instruction" );
240240 assert (STI.hasDotProd () && " Target should have Dot Product feature" );
@@ -247,8 +247,8 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
247247 if (DstTy.getScalarSizeInBits () != 32 || MidTy.getScalarSizeInBits () != 32 )
248248 return false ;
249249
250- // Detect mul(ext, ext) with symetric ext's. If I1Opc is G_ZEXT or G_SEXT then
251- // the ext's must match the same opcode. It is set to the ext opcode on
250+ // Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT
251+ // then the ext's must match the same opcode. It is set to the ext opcode on
252252 // output.
253253 auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
254254 Register &Out2, unsigned &I1Opc) {
@@ -315,11 +315,11 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
315315 return true ;
316316}
317317
318- void applyExtAddvToUdotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
319- MachineIRBuilder &Builder,
320- GISelChangeObserver &Observer,
321- const AArch64Subtarget &STI,
322- std::tuple<Register, Register, bool > &MatchInfo) {
318+ void applyExtAddvToDotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
319+ MachineIRBuilder &Builder,
320+ GISelChangeObserver &Observer,
321+ const AArch64Subtarget &STI,
322+ std::tuple<Register, Register, bool > &MatchInfo) {
323323 assert (MI.getOpcode () == TargetOpcode::G_VECREDUCE_ADD &&
324324 " Expected a G_VECREDUCE_ADD instruction" );
325325 assert (STI.hasDotProd () && " Target should have Dot Product feature" );
@@ -581,14 +581,14 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
581581}
582582
583583// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
584- // extend instruction at the end by allowing selection of {s|u}addl sooner i32
585- // add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
584+ // extend instruction at the end by allowing selection of {s|u}addl sooner
585+ // i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
586586bool matchPushAddSubExt (MachineInstr &MI, MachineRegisterInfo &MRI,
587587 Register DstReg, Register SrcReg1, Register SrcReg2) {
588588 assert ((MI.getOpcode () == TargetOpcode::G_ADD ||
589589 MI.getOpcode () == TargetOpcode::G_SUB ||
590590 MI.getOpcode () == TargetOpcode::G_MUL) &&
591- " Expected a G_ADD or G_SUB instruction\n " );
591+ " Expected a G_ADD, G_SUB or G_MUL instruction\n " );
592592
593593 // Deal with vector types only
594594 LLT DstTy = MRI.getType (DstReg);
@@ -623,7 +623,8 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
623623 // G_SUB has to sign-extend the result.
624624 // G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
625625 // needs to use the original opcode so the original opcode is used for both.
626- if (MI.getOpcode () != TargetOpcode::G_SUB)
626+ if (MI.getOpcode () == TargetOpcode::G_ADD ||
627+ MI.getOpcode () == TargetOpcode::G_MUL)
627628 B.buildInstr (Opc, {DstReg}, {AddReg});
628629 else
629630 B.buildSExt (DstReg, AddReg);
0 commit comments