Skip to content

Commit 2ad6cc8

Browse files
committed
[RISCV][TII] Add and use new hook to optimize/canonicalize instructions after MachineCopyPropagation
PR llvm#136875 was posted as a draft PR that handled a subset of these cases, using the CompressPat mechanism. The consensus from that discussion (and a conclusion I agree with) is that it would be beneficial doing this optimisation earlier on, and in a way that isn't limited just to cases that can be handled by instruction compression. The most common source for instructions that can be optimized/canonicalized in this way is through tail duplication followed by machine copy propagation. For RISC-V, choosing a more canonical instruction allows it to be compressed when it couldn't be before. There is the potential that it would make other MI-level optimisations easier. This modifies ~910 instructions across an llvm-test-suite build including SPEC2017, targeting rva22u64. Coverage of instructions is based on observations from a script written to find redundant or improperly canonicalized instructions (though I aim to support all instructions in a 'group' at once, e.g. MUL* even if I only saw some variants of MUL in practice).
1 parent 2e713af commit 2ad6cc8

File tree

5 files changed

+921
-0
lines changed

5 files changed

+921
-0
lines changed

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,16 @@ class TargetInstrInfo : public MCInstrInfo {
510510
return false;
511511
}
512512

513+
/// If possible, converts the instruction to a more 'optimized'/canonical
514+
/// form. Returns true if the instruction was modified.
515+
///
516+
/// This function is only called after register allocation. The MI will be
517+
/// modified in place. This is called by passes such as
518+
/// MachineCopyPropagation, where their mutation of the MI operands may
519+
/// expose opportunities to convert the instruction to a simpler form (e.g.
520+
/// a load of 0).
521+
/// The default implementation leaves \p MI untouched and returns false.
virtual bool optimizeInstruction(MachineInstr &MI) const { return false; }
522+
513523
/// A pair composed of a register and a sub-register index.
514524
/// Used to give some type checking when modeling Reg:SubReg.
515525
struct RegSubRegPair {

llvm/lib/CodeGen/MachineCopyPropagation.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,12 @@ void MachineCopyPropagation::forwardUses(MachineInstr &MI) {
867867
make_range(Copy->getIterator(), std::next(MI.getIterator())))
868868
KMI.clearRegisterKills(CopySrcReg, TRI);
869869

870+
// Attempt to canonicalize/optimize the instruction now its arguments have
871+
// been mutated.
872+
if (TII->optimizeInstruction(MI)) {
873+
LLVM_DEBUG(dbgs() << "MCP: After optimizeInstruction: " << MI << "\n");
874+
}
875+
870876
++NumCopyForwards;
871877
Changed = true;
872878
}

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,6 +2344,21 @@ static unsigned getSHXADDShiftAmount(unsigned Opc) {
23442344
}
23452345
}
23462346

2347+
// Returns the shift amount from a SHXADD.UW instruction. Returns 0 if the
2348+
// instruction is not a SHXADD.UW.
2349+
static unsigned getSHXADDUWShiftAmount(unsigned Opc) {
2350+
switch (Opc) {
2351+
default:
2352+
return 0;
2353+
case RISCV::SH1ADD_UW:
2354+
return 1;
2355+
case RISCV::SH2ADD_UW:
2356+
return 2;
2357+
case RISCV::SH3ADD_UW:
2358+
return 3;
2359+
}
2360+
}
2361+
23472362
// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
23482363
// (sh3add (sh2add Y, Z), X).
23492364
static bool getSHXADDPatterns(const MachineInstr &Root,
@@ -3850,6 +3865,209 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI,
38503865
return TargetInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2);
38513866
}
38523867

3868+
/// Rewrites \p MI in place into a simpler / more canonical equivalent when
/// its operands allow it, e.g. when one of the sources is the zero register
/// X0 or both sources are the same register. Every rewrite targets ADDI,
/// ADDIW, SLLI or SLLI_UW.
///
/// "li rd, N" in the comments below is shorthand for "addi rd, zero, N".
///
/// \returns true if \p MI was modified, false if it was left untouched.
bool RISCVInstrInfo::optimizeInstruction(MachineInstr &MI) const {
  // True when operand Idx (which must be a register operand) is X0.
  auto IsZeroReg = [&MI](unsigned Idx) {
    return MI.getOperand(Idx).getReg() == RISCV::X0;
  };
  // Removes operand 1 and adds it back after the remaining source operand,
  // swapping the two sources of a "rd, rs1, rs2" instruction. Only valid for
  // commutative opcodes.
  auto SwapSources = [&MI]() {
    MachineOperand MO1 = MI.getOperand(1);
    MI.removeOperand(1);
    MI.addOperand(MO1);
  };

  switch (MI.getOpcode()) {
  default:
    break;
  case RISCV::OR:
  case RISCV::XOR:
    // Normalize (so the next clause fires):
    // [x]or rd, zero, rs => [x]or rd, rs, zero
    if (IsZeroReg(1))
      SwapSources();
    // [x]or rd, rs, zero => addi rd, rs, 0
    if (IsZeroReg(2)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // xor rd, rs, rs => li rd, 0
    // The result is zero, not rs, so the first source must become X0 too.
    if (MI.getOpcode() == RISCV::XOR &&
        MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
      MI.getOperand(1).setReg(RISCV::X0);
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::ADDW:
    // Normalize (so the next clause fires):
    // addw rd, zero, rs => addw rd, rs, zero
    if (IsZeroReg(1))
      SwapSources();
    // addw rd, rs, zero => addiw rd, rs, 0
    if (IsZeroReg(2)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDIW));
      return true;
    }
    break;
  case RISCV::SUB:
    // sub rd, rs, zero => addi rd, rs, 0
    //
    // NOTE: PACK/PACKW are deliberately not handled here. With rs2 == zero
    // they zero-extend the low half of rs1 (that is the zext.h encoding),
    // so they are not equivalent to a register move.
    if (IsZeroReg(2)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SUBW:
    // subw rd, rs, zero => addiw rd, rs, 0
    if (IsZeroReg(2)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDIW));
      return true;
    }
    break;
  case RISCV::SH1ADD:
  case RISCV::SH1ADD_UW:
  case RISCV::SH2ADD:
  case RISCV::SH2ADD_UW:
  case RISCV::SH3ADD:
  case RISCV::SH3ADD_UW:
    // shNadd[.uw] rd, zero, rs => addi rd, rs, 0
    if (IsZeroReg(1)) {
      MI.removeOperand(1);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // shNadd[.uw] rd, rs, zero => slli[.uw] rd, rs, N
    if (IsZeroReg(2)) {
      const unsigned Opc = MI.getOpcode();
      // Non-zero only for the .uw variants.
      const unsigned UWShAmt = getSHXADDUWShiftAmount(Opc);
      MI.removeOperand(2);
      if (UWShAmt != 0) {
        MI.addOperand(MachineOperand::CreateImm(UWShAmt));
        MI.setDesc(get(RISCV::SLLI_UW));
        return true;
      }
      MI.addOperand(MachineOperand::CreateImm(getSHXADDShiftAmount(Opc)));
      MI.setDesc(get(RISCV::SLLI));
      return true;
    }
    break;
  case RISCV::ANDI:
    // andi rd, zero, C => li rd, 0
    if (IsZeroReg(1)) {
      MI.getOperand(2).setImm(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::AND:
  case RISCV::MUL:
  case RISCV::MULH:
  case RISCV::MULHSU:
  case RISCV::MULHU:
  case RISCV::MULW:
    // and rd, zero, rs => li rd, 0
    // mul* rd, zero, rs => li rd, 0
    if (IsZeroReg(1)) {
      // Drop the non-zero source so that X0 remains as the only source.
      MI.removeOperand(2);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // and rd, rs, zero => li rd, 0
    // mul* rd, rs, zero => li rd, 0
    if (IsZeroReg(2)) {
      // Drop the non-zero source so that X0 remains as the only source.
      MI.removeOperand(1);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLLI:
  case RISCV::SRLI:
  case RISCV::SRAI:
  case RISCV::SLLIW:
  case RISCV::SRLIW:
  case RISCV::SRAIW:
  case RISCV::SLLI_UW:
    // shiftimm rd, zero, N => li rd, 0
    if (IsZeroReg(1)) {
      MI.getOperand(2).setImm(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::ORI:
  case RISCV::XORI:
    // [x]ori rd, zero, N => li rd, N
    if (IsZeroReg(1)) {
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLTIU:
    // sltiu rd, zero, 1 (seqz rd, zero) => li rd, 1
    // The immediate is already 1, so only the opcode changes.
    if (IsZeroReg(1) && MI.getOperand(2).getImm() == 1) {
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLTU:
    // sltu rd, zero, zero (snez rd, zero) => li rd, 0
    if (IsZeroReg(1) && IsZeroReg(2)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::ADD_UW:
    // add.uw rd, zero, rs => addi rd, rs, 0
    // add.uw computes rs2 + zext32(rs1) and is NOT commutative: with
    // rs1 == zero the result is simply rs2, so the sources must not be
    // swapped (that would turn a move into a zext.w). This also covers
    // "add.uw rd, zero, zero" (zext.w rd, zero), which becomes li rd, 0.
    if (IsZeroReg(1)) {
      MI.removeOperand(1);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SEXT_H:
  case RISCV::SEXT_B:
  case RISCV::ZEXT_H_RV32:
  case RISCV::ZEXT_H_RV64:
    // sext.[hb] rd, zero => li rd, 0
    // zext.h rd, zero => li rd, 0
    if (IsZeroReg(1)) {
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLL:
  case RISCV::SRL:
  case RISCV::SRA:
  case RISCV::SLLW:
  case RISCV::SRLW:
  case RISCV::SRAW:
    // shift rd, zero, rs => li rd, 0
    if (IsZeroReg(1)) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::MIN:
  case RISCV::MINU:
  case RISCV::MAX:
  case RISCV::MAXU:
    // min|max rd, rs, rs => addi rd, rs, 0
    if (MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  }
  return false;
}
4070+
38534071
#undef CASE_RVV_OPCODE_UNMASK_LMUL
38544072
#undef CASE_RVV_OPCODE_MASK_LMUL
38554073
#undef CASE_RVV_OPCODE_LMUL

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
242242
unsigned OpIdx1,
243243
unsigned OpIdx2) const override;
244244

245+
bool optimizeInstruction(MachineInstr &MI) const override;
246+
245247
MachineInstr *convertToThreeAddress(MachineInstr &MI, LiveVariables *LV,
246248
LiveIntervals *LIS) const override;
247249

0 commit comments

Comments
 (0)