@@ -119,22 +119,9 @@ class SIFoldOperandsImpl {
                        MachineOperand *OpToFold) const;
   bool isUseSafeToFold(const MachineInstr &MI,
                        const MachineOperand &UseMO) const;
-
-  const TargetRegisterClass *getRegSeqInit(
-      MachineInstr &RegSeq,
-      SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const;
-
-  const TargetRegisterClass *
+  bool
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-                Register UseReg) const;
-
-  std::pair<MachineOperand *, const TargetRegisterClass *>
-  isRegSeqSplat(MachineInstr &RegSeg) const;
-
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
-
+                Register UseReg, uint8_t OpTy) const;
   bool tryToFoldACImm(MachineOperand &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
                       SmallVectorImpl<FoldCandidate> &FoldList) const;
@@ -814,24 +801,19 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
   return Sub;
 }
 
-const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
-    MachineInstr &RegSeq,
-    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const {
-
-  assert(RegSeq.isRegSequence());
-
-  const TargetRegisterClass *RC = nullptr;
-
-  for (unsigned I = 1, E = RegSeq.getNumExplicitOperands(); I != E; I += 2) {
-    MachineOperand &SrcOp = RegSeq.getOperand(I);
-    unsigned SubRegIdx = RegSeq.getOperand(I + 1).getImm();
+// Find a def of the UseReg, check if it is a reg_sequence and find initializers
+// for each subreg, tracking it to foldable inline immediate if possible.
+// Returns true on success.
+bool SIFoldOperandsImpl::getRegSeqInit(
+    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
+    Register UseReg, uint8_t OpTy) const {
+  MachineInstr *Def = MRI->getVRegDef(UseReg);
+  if (!Def || !Def->isRegSequence())
+    return false;
 
-    // Only accept reg_sequence with uniform reg class inputs for simplicity.
-    const TargetRegisterClass *OpRC = getRegOpRC(*MRI, *TRI, SrcOp);
-    if (!RC)
-      RC = OpRC;
-    else if (!TRI->getCommonSubClass(RC, OpRC))
-      return nullptr;
+  for (unsigned I = 1, E = Def->getNumExplicitOperands(); I != E; I += 2) {
+    MachineOperand &SrcOp = Def->getOperand(I);
+    unsigned SubRegIdx = Def->getOperand(I + 1).getImm();
 
     if (SrcOp.getSubReg()) {
       // TODO: Handle subregister compose
@@ -840,106 +822,16 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
     }
 
     MachineOperand *DefSrc = lookUpCopyChain(*TII, *MRI, SrcOp.getReg());
-    if (DefSrc && (DefSrc->isReg() || DefSrc->isImm())) {
+    if (DefSrc && (DefSrc->isReg() ||
+                   (DefSrc->isImm() && TII->isInlineConstant(*DefSrc, OpTy)))) {
       Defs.emplace_back(DefSrc, SubRegIdx);
       continue;
     }
 
     Defs.emplace_back(&SrcOp, SubRegIdx);
   }
 
-  return RC;
-}
-
-// Find a def of the UseReg, check if it is a reg_sequence and find initializers
-// for each subreg, tracking it to an immediate if possible. Returns the
-// register class of the inputs on success.
-const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
-    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-    Register UseReg) const {
-  MachineInstr *Def = MRI->getVRegDef(UseReg);
-  if (!Def || !Def->isRegSequence())
-    return nullptr;
-
-  return getRegSeqInit(*Def, Defs);
-}
-
-std::pair<MachineOperand *, const TargetRegisterClass *>
-SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
-  SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
-  const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
-  if (!SrcRC)
-    return {};
-
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
-  int64_t Imm;
-  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
-    const MachineOperand *Op = Defs[I].first;
-    if (!Op->isImm())
-      return {};
-
-    int64_t SubImm = Op->getImm();
-    if (!I) {
-      Imm = SubImm;
-      continue;
-    }
-    if (Imm != SubImm)
-      return {}; // Can only fold splat constants
-  }
-
-  return {Defs[0].first, SrcRC};
-}
-
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
-    const TargetRegisterClass *SplatRC) const {
-  const MCInstrDesc &Desc = UseMI->getDesc();
-  if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
-
-  // Filter out unhandled pseudos.
-  if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
-
-  int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
-  if (RCID == -1)
-    return nullptr;
-
-  // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
-    // We need to figure out the scalar type read by the operand. e.g. the MFMA
-    // operand will be AReg_128, and we want to check if it's compatible with an
-    // AReg_32 constant.
-    uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
-    switch (OpTy) {
-    case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
-    case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
-      OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0);
-      break;
-    case AMDGPU::OPERAND_REG_INLINE_AC_FP64:
-      OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
-      break;
-    default:
-      return nullptr;
-    }
-
-    if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
-  }
-
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
-
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
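
For context: the restored getRegSeqInit walks the (input, subregister-index) operand pairs of a REG_SEQUENCE and, via lookUpCopyChain, chases each input through intervening copies until it bottoms out in an immediate or an opaque register. The following standalone sketch models just that chase; the Def/DefMap types and traceToImm are illustrative stand-ins, not LLVM API.

// Minimal standalone model (not LLVM code) of the copy-chain walk.
#include <cstdint>
#include <map>
#include <optional>

// Hypothetical IR model: a vreg is either a COPY of another vreg or a
// move-immediate.
struct Def {
  std::optional<unsigned> CopyOf; // set if this vreg is a COPY
  std::optional<int64_t> Imm;     // set if this vreg is a move-immediate
};
using DefMap = std::map<unsigned, Def>;

// Follow the copy chain from Reg; return the defining immediate if any.
// Usage: with {2 -> COPY of 1, 1 -> Imm 0}, traceToImm(M, 2) yields 0.
std::optional<int64_t> traceToImm(const DefMap &MRI, unsigned Reg) {
  for (unsigned Steps = 0; Steps < 16; ++Steps) { // guard against cycles
    auto It = MRI.find(Reg);
    if (It == MRI.end())
      return std::nullopt;
    if (It->second.Imm)
      return It->second.Imm;
    if (!It->second.CopyOf)
      return std::nullopt; // opaque register, not an immediate
    Reg = *It->second.CopyOf;
  }
  return std::nullopt;
}
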
@@ -953,6 +845,7 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
     return false;
 
+  uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
   MachineOperand &UseOp = UseMI->getOperand(UseOpIdx);
   if (OpToFold.isImm()) {
     if (unsigned UseSubReg = UseOp.getSubReg()) {
@@ -999,7 +892,31 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
     }
   }
 
-  return false;
+  SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
+  if (!getRegSeqInit(Defs, UseReg, OpTy))
+    return false;
+
+  int32_t Imm;
+  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
+    const MachineOperand *Op = Defs[I].first;
+    if (!Op->isImm())
+      return false;
+
+    auto SubImm = Op->getImm();
+    if (!I) {
+      Imm = SubImm;
+      if (!TII->isInlineConstant(*Op, OpTy) ||
+          !TII->isOperandLegal(*UseMI, UseOpIdx, Op))
+        return false;
+
+      continue;
+    }
+    if (Imm != SubImm)
+      return false; // Can only fold splat constants
+  }
+
+  appendFoldCandidate(FoldList, UseMI, UseOpIdx, Defs[0].first);
+  return true;
 }
 
 void SIFoldOperandsImpl::foldOperand(
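
The splat handling that previously lived in isRegSeqSplat/tryFoldRegSeqSplat is folded back into tryToFoldACImm above: every REG_SEQUENCE element must be an immediate, the first must be an inline constant legal for the use operand, and all elements must equal the first. Below is a standalone model of that check; isInlineConstant here is a simplified stand-in (the real SIInstrInfo predicate also accepts certain float bit patterns and depends on the operand type).

// Minimal standalone model of the splat check (not LLVM code).
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in: AMDGPU encodes small integers (-16..64) as inline operands.
bool isInlineConstant(int64_t Imm) { return Imm >= -16 && Imm <= 64; }

// Each element is an immediate (engaged optional) or not. Returns the
// foldable splat value, or nullopt if the elements are not a uniform
// inline constant.
std::optional<int64_t>
findSplat(const std::vector<std::optional<int64_t>> &Elts) {
  std::optional<int64_t> Splat;
  for (const std::optional<int64_t> &E : Elts) {
    if (!E)
      return std::nullopt;   // non-immediate element
    if (!Splat) {
      if (!isInlineConstant(*E))
        return std::nullopt; // first element must be inline-legal
      Splat = E;
      continue;
    }
    if (*Splat != *E)
      return std::nullopt;   // can only fold splat constants
  }
  return Splat;
}
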
@@ -1029,34 +946,21 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
-    const TargetRegisterClass *SplatRC;
-    std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
-
     // Grab the use operands first
     SmallVector<MachineOperand *, 4> UsesToProcess(
         llvm::make_pointer_range(MRI->use_nodbg_operands(RegSeqDstReg)));
     for (auto *RSUse : UsesToProcess) {
       MachineInstr *RSUseMI = RSUse->getParent();
-      unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          appendFoldCandidate(FoldList, RSUseMI, OpNo, Foldable);
-          continue;
-        }
-      }
-
-      if (RSUse->getSubReg() != RegSeqDstSubReg)
+      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI,
+                         RSUseMI->getOperandNo(RSUse), FoldList))
         continue;
 
-      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI, OpNo, FoldList))
+      if (RSUse->getSubReg() != RegSeqDstSubReg)
         continue;
 
-      foldOperand(OpToFold, RSUseMI, OpNo, FoldList, CopiesToReplace);
+      foldOperand(OpToFold, RSUseMI, RSUseMI->getOperandNo(RSUse), FoldList, CopiesToReplace);
     }
-
     return;
   }
 
@@ -2217,7 +2121,7 @@ bool SIFoldOperandsImpl::tryFoldRegSequence(MachineInstr &MI) {
     return false;
 
   SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
-  if (!getRegSeqInit(Defs, Reg))
+  if (!getRegSeqInit(Defs, Reg, MCOI::OPERAND_REGISTER))
     return false;
 
   for (auto &[Op, SubIdx] : Defs) {
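
With the TargetRegisterClass-returning overloads gone, tryFoldRegSequence supplies an explicit operand type; passing MCOI::OPERAND_REGISTER effectively tells getRegSeqInit not to classify immediates as inline-foldable for this caller. A sketch of how a caller consumes the restored API (fragment only: it assumes the surrounding pass state from this file, and VReg is a hypothetical virtual register):

// Fragment, not runnable on its own; uses the restored getRegSeqInit.
SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
if (getRegSeqInit(Defs, VReg, AMDGPU::OPERAND_REG_INLINE_AC_FP32)) {
  for (auto &[Op, SubIdx] : Defs) {
    // Op is either an immediate that is inline-foldable for the OpTy
    // passed above, or the original register operand; SubIdx is the
    // subregister index from the REG_SEQUENCE.
  }
}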