@@ -119,22 +119,9 @@ class SIFoldOperandsImpl {
                          MachineOperand *OpToFold) const;
   bool isUseSafeToFold(const MachineInstr &MI,
                        const MachineOperand &UseMO) const;
-
-  const TargetRegisterClass *getRegSeqInit(
-      MachineInstr &RegSeq,
-      SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const;
-
-  const TargetRegisterClass *
+  bool
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-                Register UseReg) const;
-
-  std::pair<MachineOperand *, const TargetRegisterClass *>
-  isRegSeqSplat(MachineInstr &RegSeg) const;
-
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
-
+                Register UseReg, uint8_t OpTy) const;
   bool tryToFoldACImm(MachineOperand &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
                       SmallVectorImpl<FoldCandidate> &FoldList) const;
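
The hunk above collapses the splat-handling interface back into a single
entry point: getRegSeqInit now reports success as a bool and takes the use
operand's type, so only immediates that are inline-legal for that operand
are tracked. A minimal sketch of the resulting call shape, mirroring the
call sites later in this patch (illustrative, not a literal excerpt):

    // Collect (initializer, subreg index) pairs for the REG_SEQUENCE that
    // defines UseReg; fails if UseReg is not defined by a reg_sequence.
    SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
    if (!getRegSeqInit(Defs, UseReg, OpTy))
      return false;
    // Each Defs[i].first is either an immediate traced through copies (kept
    // only when inline-legal for OpTy) or the original register operand.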
@@ -838,24 +825,19 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
   return Sub;
 }
 
-const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
-    MachineInstr &RegSeq,
-    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const {
-
-  assert(RegSeq.isRegSequence());
-
-  const TargetRegisterClass *RC = nullptr;
-
-  for (unsigned I = 1, E = RegSeq.getNumExplicitOperands(); I != E; I += 2) {
-    MachineOperand &SrcOp = RegSeq.getOperand(I);
-    unsigned SubRegIdx = RegSeq.getOperand(I + 1).getImm();
+// Find a def of the UseReg, check if it is a reg_sequence and find initializers
+// for each subreg, tracking it to foldable inline immediate if possible.
+// Returns true on success.
+bool SIFoldOperandsImpl::getRegSeqInit(
+    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
+    Register UseReg, uint8_t OpTy) const {
+  MachineInstr *Def = MRI->getVRegDef(UseReg);
+  if (!Def || !Def->isRegSequence())
+    return false;
 
-    // Only accept reg_sequence with uniform reg class inputs for simplicity.
-    const TargetRegisterClass *OpRC = getRegOpRC(*MRI, *TRI, SrcOp);
-    if (!RC)
-      RC = OpRC;
-    else if (!TRI->getCommonSubClass(RC, OpRC))
-      return nullptr;
+  for (unsigned I = 1, E = Def->getNumExplicitOperands(); I != E; I += 2) {
+    MachineOperand &SrcOp = Def->getOperand(I);
+    unsigned SubRegIdx = Def->getOperand(I + 1).getImm();
 
     if (SrcOp.getSubReg()) {
       // TODO: Handle subregister compose
@@ -864,106 +846,16 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
     }
 
     MachineOperand *DefSrc = lookUpCopyChain(*TII, *MRI, SrcOp.getReg());
-    if (DefSrc && (DefSrc->isReg() || DefSrc->isImm())) {
+    if (DefSrc && (DefSrc->isReg() ||
+                   (DefSrc->isImm() && TII->isInlineConstant(*DefSrc, OpTy)))) {
       Defs.emplace_back(DefSrc, SubRegIdx);
       continue;
     }
 
     Defs.emplace_back(&SrcOp, SubRegIdx);
   }
 
-  return RC;
-}
-
-// Find a def of the UseReg, check if it is a reg_sequence and find initializers
-// for each subreg, tracking it to an immediate if possible. Returns the
-// register class of the inputs on success.
-const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
-    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-    Register UseReg) const {
-  MachineInstr *Def = MRI->getVRegDef(UseReg);
-  if (!Def || !Def->isRegSequence())
-    return nullptr;
-
-  return getRegSeqInit(*Def, Defs);
-}
-
-std::pair<MachineOperand *, const TargetRegisterClass *>
-SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
-  SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
-  const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
-  if (!SrcRC)
-    return {};
-
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
-  int64_t Imm;
-  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
-    const MachineOperand *Op = Defs[I].first;
-    if (!Op->isImm())
-      return {};
-
-    int64_t SubImm = Op->getImm();
-    if (!I) {
-      Imm = SubImm;
-      continue;
-    }
-    if (Imm != SubImm)
-      return {}; // Can only fold splat constants
-  }
-
-  return {Defs[0].first, SrcRC};
-}
-
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
-    const TargetRegisterClass *SplatRC) const {
-  const MCInstrDesc &Desc = UseMI->getDesc();
-  if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
-
-  // Filter out unhandled pseudos.
-  if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
-
-  int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
-  if (RCID == -1)
-    return nullptr;
-
-  // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
-    // We need to figure out the scalar type read by the operand. e.g. the MFMA
-    // operand will be AReg_128, and we want to check if it's compatible with an
-    // AReg_32 constant.
-    uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
-    switch (OpTy) {
-    case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
-    case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
-      OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0);
-      break;
-    case AMDGPU::OPERAND_REG_INLINE_AC_FP64:
-      OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
-      break;
-    default:
-      return nullptr;
-    }
-
-    if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
-  }
-
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
-
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
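
Note that the rewritten getRegSeqInit only records a traced immediate when
TII->isInlineConstant(*DefSrc, OpTy) accepts it; anything else is kept in
register form, so no illegal fold is attempted later. As intuition, a toy
stand-in for that predicate on a 32-bit integer operand (the real check in
SIInstrInfo depends on the operand type and also accepts a small set of
floating-point bit patterns):

    #include <cstdint>

    // AMDGPU encodes integers in [-16, 64] as inline constants, so only
    // such immediates are free to fold into an instruction operand.
    static bool isInlineImm32(int64_t Imm) { return Imm >= -16 && Imm <= 64; }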
@@ -977,6 +869,7 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
     return false;
 
+  uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
   MachineOperand &UseOp = UseMI->getOperand(UseOpIdx);
   if (OpToFold.isImm()) {
     if (unsigned UseSubReg = UseOp.getSubReg()) {
@@ -1023,7 +916,31 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
     }
   }
 
-  return false;
+  SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
+  if (!getRegSeqInit(Defs, UseReg, OpTy))
+    return false;
+
+  int32_t Imm;
+  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
+    const MachineOperand *Op = Defs[I].first;
+    if (!Op->isImm())
+      return false;
+
+    auto SubImm = Op->getImm();
+    if (!I) {
+      Imm = SubImm;
+      if (!TII->isInlineConstant(*Op, OpTy) ||
+          !TII->isOperandLegal(*UseMI, UseOpIdx, Op))
+        return false;
+
+      continue;
+    }
+    if (Imm != SubImm)
+      return false; // Can only fold splat constants
+  }
+
+  appendFoldCandidate(FoldList, UseMI, UseOpIdx, Defs[0].first);
+  return true;
 }
 
 void SIFoldOperandsImpl::foldOperand(
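
The loop added to tryToFoldACImm is the old isRegSeqSplat logic inlined:
every reg_sequence element must carry the same immediate, and the
inline-constant and operand-legality checks run once on the first element
since the rest are identical by definition. A self-contained model of just
the splat test (toy types; the real code walks MachineOperands):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Returns the common immediate if all elements match, else nullopt.
    static std::optional<int64_t>
    getSplatImm(const std::vector<int64_t> &Elts) {
      std::optional<int64_t> Imm;
      for (int64_t V : Elts) {
        if (Imm && *Imm != V)
          return std::nullopt; // Can only fold splat constants.
        Imm = V;
      }
      return Imm;
    }

    int main() {
      // {1, 1, 1, 1} is a splat; {0, 1, 0, 1} (a 64-bit constant split
      // into 32-bit halves) is not.
      return getSplatImm({1, 1, 1, 1}) && !getSplatImm({0, 1, 0, 1}) ? 0 : 1;
    }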
@@ -1053,34 +970,21 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
-    const TargetRegisterClass *SplatRC;
-    std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
-
     // Grab the use operands first
     SmallVector<MachineOperand *, 4> UsesToProcess(
         llvm::make_pointer_range(MRI->use_nodbg_operands(RegSeqDstReg)));
     for (auto *RSUse : UsesToProcess) {
       MachineInstr *RSUseMI = RSUse->getParent();
-      unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          appendFoldCandidate(FoldList, RSUseMI, OpNo, Foldable);
-          continue;
-        }
-      }
-
-      if (RSUse->getSubReg() != RegSeqDstSubReg)
+      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI,
+                         RSUseMI->getOperandNo(RSUse), FoldList))
         continue;
 
-      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI, OpNo, FoldList))
+      if (RSUse->getSubReg() != RegSeqDstSubReg)
         continue;
 
-      foldOperand(OpToFold, RSUseMI, OpNo, FoldList, CopiesToReplace);
+      foldOperand(OpToFold, RSUseMI, RSUseMI->getOperandNo(RSUse), FoldList, CopiesToReplace);
     }
-
     return;
   }
 
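
Net effect in foldOperand's reg_sequence path: each user of the tuple first
gets a chance to fold the whole thing as a (splat) immediate via
tryToFoldACImm, and only uses reading this operand's own lane fall through
to the recursive foldOperand. Condensed control flow of the loop above, as
a commented paraphrase rather than a literal excerpt:

    for (MachineOperand *RSUse : UsesToProcess) {
      MachineInstr *RSUseMI = RSUse->getParent();
      unsigned OpNo = RSUseMI->getOperandNo(RSUse);
      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI, OpNo, FoldList))
        continue; // Whole tuple folded as a splat inline immediate.
      if (RSUse->getSubReg() != RegSeqDstSubReg)
        continue; // Use reads a different lane than the operand we fold.
      foldOperand(OpToFold, RSUseMI, OpNo, FoldList, CopiesToReplace);
    }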
@@ -2232,7 +2136,7 @@ bool SIFoldOperandsImpl::tryFoldRegSequence(MachineInstr &MI) {
     return false;
 
   SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
-  if (!getRegSeqInit(Defs, Reg))
+  if (!getRegSeqInit(Defs, Reg, MCOI::OPERAND_REGISTER))
     return false;
 
   for (auto &[Op, SubIdx] : Defs) {