@@ -819,92 +819,87 @@ void RISCVInstrInfo::loadRegFromStackSlot(
819819 .setMIFlag (Flags);
820820 }
821821}
822+ std::optional<unsigned > getFoldedOpcode (MachineFunction &MF, MachineInstr &MI,
823+ ArrayRef<unsigned > Ops,
824+ const RISCVSubtarget &ST) {
822825
823- MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl (
824- MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned > Ops,
825- MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
826- VirtRegMap *VRM) const {
827826 // The below optimizations narrow the load so they are only valid for little
828827 // endian.
829828 // TODO: Support big endian by adding an offset into the frame object?
830829 if (MF.getDataLayout ().isBigEndian ())
831- return nullptr ;
830+ return std:: nullopt ;
832831
833832 // Fold load from stack followed by sext.b/sext.h/sext.w/zext.b/zext.h/zext.w.
834833 if (Ops.size () != 1 || Ops[0 ] != 1 )
835- return nullptr ;
834+ return std:: nullopt ;
836835
837- unsigned LoadOpc;
838836 switch (MI.getOpcode ()) {
839837 default :
840- if (RISCV::isSEXT_W (MI)) {
841- LoadOpc = RISCV::LW;
842- break ;
843- }
844- if (RISCV::isZEXT_W (MI)) {
845- LoadOpc = RISCV::LWU;
846- break ;
847- }
848- if (RISCV::isZEXT_B (MI)) {
849- LoadOpc = RISCV::LBU;
850- break ;
851- }
852- if (RISCV::getRVVMCOpcode (MI.getOpcode ()) == RISCV::VMV_X_S) {
853- unsigned Log2SEW =
854- MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
855- if (STI.getXLen () < (1U << Log2SEW))
856- return nullptr ;
857- switch (Log2SEW) {
858- case 3 :
859- LoadOpc = RISCV::LB;
860- break ;
861- case 4 :
862- LoadOpc = RISCV::LH;
863- break ;
864- case 5 :
865- LoadOpc = RISCV::LW;
866- break ;
867- case 6 :
868- LoadOpc = RISCV::LD;
869- break ;
870- default :
871- llvm_unreachable (" Unexpected SEW" );
872- }
873- break ;
874- }
875- if (RISCV::getRVVMCOpcode (MI.getOpcode ()) == RISCV::VFMV_F_S) {
876- unsigned Log2SEW =
877- MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
878- switch (Log2SEW) {
879- case 4 :
880- LoadOpc = RISCV::FLH;
881- break ;
882- case 5 :
883- LoadOpc = RISCV::FLW;
884- break ;
885- case 6 :
886- LoadOpc = RISCV::FLD;
887- break ;
888- default :
889- llvm_unreachable (" Unexpected SEW" );
890- }
891- break ;
892- }
893- return nullptr ;
894- case RISCV::SEXT_H:
895- LoadOpc = RISCV::LH;
838+ if (RISCV::isSEXT_W (MI))
839+ return RISCV::LW;
840+ if (RISCV::isZEXT_W (MI))
841+ return RISCV::LWU;
842+ if (RISCV::isZEXT_B (MI))
843+ return RISCV::LBU;
896844 break ;
845+ case RISCV::SEXT_H:
846+ return RISCV::LH;
897847 case RISCV::SEXT_B:
898- LoadOpc = RISCV::LB;
899- break ;
848+ return RISCV::LB;
900849 case RISCV::ZEXT_H_RV32:
901850 case RISCV::ZEXT_H_RV64:
902- LoadOpc = RISCV::LHU;
903- break ;
851+ return RISCV::LHU;
852+ }
853+
854+ switch (RISCV::getRVVMCOpcode (MI.getOpcode ())) {
855+ default :
856+ return std::nullopt ;
857+ case RISCV::VMV_X_S: {
858+ unsigned Log2SEW =
859+ MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
860+ if (ST.getXLen () < (1U << Log2SEW))
861+ return std::nullopt ;
862+ switch (Log2SEW) {
863+ case 3 :
864+ return RISCV::LB;
865+ case 4 :
866+ return RISCV::LH;
867+ case 5 :
868+ return RISCV::LW;
869+ case 6 :
870+ return RISCV::LD;
871+ default :
872+ llvm_unreachable (" Unexpected SEW" );
873+ }
904874 }
875+ case RISCV::VFMV_F_S: {
876+ unsigned Log2SEW =
877+ MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
878+ switch (Log2SEW) {
879+ case 4 :
880+ return RISCV::FLH;
881+ case 5 :
882+ return RISCV::FLW;
883+ case 6 :
884+ return RISCV::FLD;
885+ default :
886+ llvm_unreachable (" Unexpected SEW" );
887+ }
888+ }
889+ }
890+ }
905891
892+ // This is the version used during inline spilling
893+ MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl (
894+ MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned > Ops,
895+ MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
896+ VirtRegMap *VRM) const {
897+
898+ std::optional<unsigned > LoadOpc = getFoldedOpcode (MF, MI, Ops, STI);
899+ if (!LoadOpc)
900+ return nullptr ;
906901 Register DstReg = MI.getOperand (0 ).getReg ();
907- return BuildMI (*MI.getParent (), InsertPt, MI.getDebugLoc (), get (LoadOpc),
902+ return BuildMI (*MI.getParent (), InsertPt, MI.getDebugLoc (), get (* LoadOpc),
908903 DstReg)
909904 .addFrameIndex (FrameIndex)
910905 .addImm (0 );
0 commit comments