@@ -57,6 +57,17 @@ class RISCVInstructionSelector : public InstructionSelector {
5757 const TargetRegisterClass *
5858 getRegClassForTypeOnBank (LLT Ty, const RegisterBank &RB) const ;
5959
60+ static constexpr unsigned MaxRecursionDepth = 6 ;
61+
62+ bool hasAllNBitUsers (const MachineInstr &MI, unsigned Bits,
63+ const unsigned Depth = 0 ) const ;
64+ bool hasAllHUsers (const MachineInstr &MI) const {
65+ return hasAllNBitUsers (MI, 16 );
66+ }
67+ bool hasAllWUsers (const MachineInstr &MI) const {
68+ return hasAllNBitUsers (MI, 32 );
69+ }
70+
6071 bool isRegInGprb (Register Reg) const ;
6172 bool isRegInFprb (Register Reg) const ;
6273
@@ -184,6 +195,79 @@ RISCVInstructionSelector::RISCVInstructionSelector(
184195{
185196}
186197
198+ // Mimics optimizations in ISel and RISCVOptWInst Pass
199+ bool RISCVInstructionSelector::hasAllNBitUsers (const MachineInstr &MI,
200+ unsigned Bits,
201+ const unsigned Depth) const {
202+
203+ assert ((MI.getOpcode () == TargetOpcode::G_ADD ||
204+ MI.getOpcode () == TargetOpcode::G_SUB ||
205+ MI.getOpcode () == TargetOpcode::G_MUL ||
206+ MI.getOpcode () == TargetOpcode::G_SHL ||
207+ MI.getOpcode () == TargetOpcode::G_LSHR ||
208+ MI.getOpcode () == TargetOpcode::G_AND ||
209+ MI.getOpcode () == TargetOpcode::G_OR ||
210+ MI.getOpcode () == TargetOpcode::G_XOR ||
211+ MI.getOpcode () == TargetOpcode::G_SEXT_INREG || Depth != 0 ) &&
212+ " Unexpected opcode" );
213+
214+ if (Depth >= RISCVInstructionSelector::MaxRecursionDepth)
215+ return false ;
216+
217+ auto DestReg = MI.getOperand (0 ).getReg ();
218+ for (auto &UserOp : MRI->use_nodbg_operands (DestReg)) {
219+ assert (UserOp.getParent () && " UserOp must have a parent" );
220+ const MachineInstr &UserMI = *UserOp.getParent ();
221+ unsigned OpIdx = UserOp.getOperandNo ();
222+
223+ switch (UserMI.getOpcode ()) {
224+ default :
225+ return false ;
226+ case RISCV::ADDW:
227+ case RISCV::ADDIW:
228+ case RISCV::SUBW:
229+ if (Bits >= 32 )
230+ break ;
231+ return false ;
232+ case RISCV::SLL:
233+ case RISCV::SRA:
234+ case RISCV::SRL:
235+ // Shift amount operands only use log2(Xlen) bits.
236+ if (OpIdx == 2 && Bits >= Log2_32 (Subtarget->getXLen ()))
237+ break ;
238+ return false ;
239+ case RISCV::SLLI:
240+ // SLLI only uses the lower (XLen - ShAmt) bits.
241+ if (Bits >= Subtarget->getXLen () - UserMI.getOperand (2 ).getImm ())
242+ break ;
243+ return false ;
244+ case RISCV::ANDI:
245+ if (Bits >= (unsigned )llvm::bit_width<uint64_t >(
246+ (uint64_t )UserMI.getOperand (2 ).getImm ()))
247+ break ;
248+ goto RecCheck;
249+ case RISCV::AND:
250+ case RISCV::OR:
251+ case RISCV::XOR:
252+ RecCheck:
253+ if (hasAllNBitUsers (UserMI, Bits, Depth + 1 ))
254+ break ;
255+ return false ;
256+ case RISCV::SRLI: {
257+ unsigned ShAmt = UserMI.getOperand (2 ).getImm ();
258+ // If we are shifting right by less than Bits, and users don't demand any
259+ // bits that were shifted into [Bits-1:0], then we can consider this as an
260+ // N-Bit user.
261+ if (Bits > ShAmt && hasAllNBitUsers (UserMI, Bits - ShAmt, Depth + 1 ))
262+ break ;
263+ return false ;
264+ }
265+ }
266+ }
267+
268+ return true ;
269+ }
270+
187271InstructionSelector::ComplexRendererFns
188272RISCVInstructionSelector::selectShiftMask (MachineOperand &Root,
189273 unsigned ShiftWidth) const {
0 commit comments