@@ -119,9 +119,22 @@ class SIFoldOperandsImpl {
119119 MachineOperand *OpToFold) const ;
120120 bool isUseSafeToFold (const MachineInstr &MI,
121121 const MachineOperand &UseMO) const ;
122- bool
122+
123+ const TargetRegisterClass *getRegSeqInit (
124+ MachineInstr &RegSeq,
125+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs) const ;
126+
127+ const TargetRegisterClass *
123128 getRegSeqInit (SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
124- Register UseReg, uint8_t OpTy) const ;
129+ Register UseReg) const ;
130+
131+ std::pair<MachineOperand *, const TargetRegisterClass *>
132+ isRegSeqSplat (MachineInstr &RegSeg) const ;
133+
134+ MachineOperand *tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
135+ MachineOperand *SplatVal,
136+ const TargetRegisterClass *SplatRC) const ;
137+
125138 bool tryToFoldACImm (MachineOperand &OpToFold, MachineInstr *UseMI,
126139 unsigned UseOpIdx,
127140 SmallVectorImpl<FoldCandidate> &FoldList) const ;
@@ -825,19 +838,24 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
825838 return Sub;
826839}
827840
828- // Find a def of the UseReg, check if it is a reg_sequence and find initializers
829- // for each subreg, tracking it to foldable inline immediate if possible.
830- // Returns true on success.
831- bool SIFoldOperandsImpl::getRegSeqInit (
832- SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
833- Register UseReg, uint8_t OpTy) const {
834- MachineInstr *Def = MRI->getVRegDef (UseReg);
835- if (!Def || !Def->isRegSequence ())
836- return false ;
841+ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit (
842+ MachineInstr &RegSeq,
843+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs) const {
837844
838- for (unsigned I = 1 , E = Def->getNumExplicitOperands (); I != E; I += 2 ) {
839- MachineOperand &SrcOp = Def->getOperand (I);
840- unsigned SubRegIdx = Def->getOperand (I + 1 ).getImm ();
845+ assert (RegSeq.isRegSequence ());
846+
847+ const TargetRegisterClass *RC = nullptr ;
848+
849+ for (unsigned I = 1 , E = RegSeq.getNumExplicitOperands (); I != E; I += 2 ) {
850+ MachineOperand &SrcOp = RegSeq.getOperand (I);
851+ unsigned SubRegIdx = RegSeq.getOperand (I + 1 ).getImm ();
852+
853+ // Only accept reg_sequence with uniform reg class inputs for simplicity.
854+ const TargetRegisterClass *OpRC = getRegOpRC (*MRI, *TRI, SrcOp);
855+ if (!RC)
856+ RC = OpRC;
857+ else if (!TRI->getCommonSubClass (RC, OpRC))
858+ return nullptr ;
841859
842860 if (SrcOp.getSubReg ()) {
843861 // TODO: Handle subregister compose
@@ -846,16 +864,73 @@ bool SIFoldOperandsImpl::getRegSeqInit(
846864 }
847865
848866 MachineOperand *DefSrc = lookUpCopyChain (*TII, *MRI, SrcOp.getReg ());
849- if (DefSrc && (DefSrc->isReg () ||
850- (DefSrc->isImm () && TII->isInlineConstant (*DefSrc, OpTy)))) {
867+ if (DefSrc && (DefSrc->isReg () || DefSrc->isImm ())) {
851868 Defs.emplace_back (DefSrc, SubRegIdx);
852869 continue ;
853870 }
854871
855872 Defs.emplace_back (&SrcOp, SubRegIdx);
856873 }
857874
858- return true ;
875+ return RC;
876+ }
877+
878+ // Find a def of the UseReg, check if it is a reg_sequence and find initializers
879+ // for each subreg, tracking it to an immediate if possible. Returns the
880+ // register class of the inputs on success.
881+ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit (
882+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
883+ Register UseReg) const {
884+ MachineInstr *Def = MRI->getVRegDef (UseReg);
885+ if (!Def || !Def->isRegSequence ())
886+ return nullptr ;
887+
888+ return getRegSeqInit (*Def, Defs);
889+ }
890+
891+ std::pair<MachineOperand *, const TargetRegisterClass *>
892+ SIFoldOperandsImpl::isRegSeqSplat (MachineInstr &RegSeq) const {
893+ SmallVector<std::pair<MachineOperand *, unsigned >, 32 > Defs;
894+ const TargetRegisterClass *SrcRC = getRegSeqInit (RegSeq, Defs);
895+ if (!SrcRC)
896+ return {};
897+
898+ int64_t Imm;
899+ for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
900+ const MachineOperand *Op = Defs[I].first ;
901+ if (!Op->isImm ())
902+ return {};
903+
904+ int64_t SubImm = Op->getImm ();
905+ if (!I) {
906+ Imm = SubImm;
907+ continue ;
908+ }
909+ if (Imm != SubImm)
910+ return {}; // Can only fold splat constants
911+ }
912+
913+ return {Defs[0 ].first , SrcRC};
914+ }
915+
916+ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat (
917+ MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
918+ const TargetRegisterClass *SplatRC) const {
919+ const MCInstrDesc &Desc = UseMI->getDesc ();
920+ if (UseOpIdx >= Desc.getNumOperands ())
921+ return nullptr ;
922+
923+ // Filter out unhandled pseudos.
924+ if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
925+ return nullptr ;
926+
927+ // FIXME: Verify SplatRC is compatible with the use operand
928+ uint8_t OpTy = Desc.operands ()[UseOpIdx].OperandType ;
929+ if (!TII->isInlineConstant (*SplatVal, OpTy) ||
930+ !TII->isOperandLegal (*UseMI, UseOpIdx, SplatVal))
931+ return nullptr ;
932+
933+ return SplatVal;
859934}
860935
861936bool SIFoldOperandsImpl::tryToFoldACImm (
@@ -869,7 +944,6 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
869944 if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
870945 return false ;
871946
872- uint8_t OpTy = Desc.operands ()[UseOpIdx].OperandType ;
873947 MachineOperand &UseOp = UseMI->getOperand (UseOpIdx);
874948 if (OpToFold.isImm ()) {
875949 if (unsigned UseSubReg = UseOp.getSubReg ()) {
@@ -916,31 +990,7 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
916990 }
917991 }
918992
919- SmallVector<std::pair<MachineOperand*, unsigned >, 32 > Defs;
920- if (!getRegSeqInit (Defs, UseReg, OpTy))
921- return false ;
922-
923- int32_t Imm;
924- for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
925- const MachineOperand *Op = Defs[I].first ;
926- if (!Op->isImm ())
927- return false ;
928-
929- auto SubImm = Op->getImm ();
930- if (!I) {
931- Imm = SubImm;
932- if (!TII->isInlineConstant (*Op, OpTy) ||
933- !TII->isOperandLegal (*UseMI, UseOpIdx, Op))
934- return false ;
935-
936- continue ;
937- }
938- if (Imm != SubImm)
939- return false ; // Can only fold splat constants
940- }
941-
942- appendFoldCandidate (FoldList, UseMI, UseOpIdx, Defs[0 ].first );
943- return true ;
993+ return false ;
944994}
945995
946996void SIFoldOperandsImpl::foldOperand (
@@ -970,14 +1020,26 @@ void SIFoldOperandsImpl::foldOperand(
9701020 Register RegSeqDstReg = UseMI->getOperand (0 ).getReg ();
9711021 unsigned RegSeqDstSubReg = UseMI->getOperand (UseOpIdx + 1 ).getImm ();
9721022
1023+ MachineOperand *SplatVal;
1024+ const TargetRegisterClass *SplatRC;
1025+ std::tie (SplatVal, SplatRC) = isRegSeqSplat (*UseMI);
1026+
9731027 // Grab the use operands first
9741028 SmallVector<MachineOperand *, 4 > UsesToProcess (
9751029 llvm::make_pointer_range (MRI->use_nodbg_operands (RegSeqDstReg)));
9761030 for (auto *RSUse : UsesToProcess) {
9771031 MachineInstr *RSUseMI = RSUse->getParent ();
1032+ unsigned OpNo = RSUseMI->getOperandNo (RSUse);
9781033
979- if (tryToFoldACImm (UseMI->getOperand (0 ), RSUseMI,
980- RSUseMI->getOperandNo (RSUse), FoldList))
1034+ if (SplatVal) {
1035+ if (MachineOperand *Foldable =
1036+ tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1037+ appendFoldCandidate (FoldList, RSUseMI, OpNo, Foldable);
1038+ continue ;
1039+ }
1040+ }
1041+
1042+ if (tryToFoldACImm (UseMI->getOperand (0 ), RSUseMI, OpNo, FoldList))
9811043 continue ;
9821044
9831045 if (RSUse->getSubReg () != RegSeqDstSubReg)
@@ -986,6 +1048,7 @@ void SIFoldOperandsImpl::foldOperand(
9861048 foldOperand (OpToFold, RSUseMI, RSUseMI->getOperandNo (RSUse), FoldList,
9871049 CopiesToReplace);
9881050 }
1051+
9891052 return ;
9901053 }
9911054
@@ -2137,7 +2200,7 @@ bool SIFoldOperandsImpl::tryFoldRegSequence(MachineInstr &MI) {
21372200 return false ;
21382201
21392202 SmallVector<std::pair<MachineOperand*, unsigned >, 32 > Defs;
2140- if (!getRegSeqInit (Defs, Reg, MCOI::OPERAND_REGISTER ))
2203+ if (!getRegSeqInit (Defs, Reg))
21412204 return false ;
21422205
21432206 for (auto &[Op, SubIdx] : Defs) {
0 commit comments