@@ -226,12 +226,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -964,15 +964,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -984,38 +984,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
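
For illustration only (not part of the patch): a minimal, self-contained sketch of the lane-merging check that the new fallback loop in isRegSeqSplat performs, written over a plain vector of 32-bit immediates instead of reg_sequence operands. The local Make_64 is a stand-in assumed to match llvm::Make_64 from MathExtras.h, i.e. (uint64_t)High << 32 | Low; the v4i32 example values are hypothetical.

// Standalone sketch: detect a 64-bit splat built from 32-bit pieces.
#include <cstdint>
#include <optional>
#include <vector>

static uint64_t Make_64(uint32_t High, uint32_t Low) {
  return (uint64_t(High) << 32) | Low;
}

// Merge consecutive (Low, High) pairs and require every 64-bit lane to be
// identical, mirroring the second loop in isRegSeqSplat.
static std::optional<uint64_t> matchSplat64(const std::vector<uint32_t> &Elts) {
  if (Elts.size() < 2 || Elts.size() % 2 != 0)
    return std::nullopt;
  uint64_t Splat = Make_64(Elts[1], Elts[0]);
  for (size_t I = 2; I != Elts.size(); I += 2)
    if (Make_64(Elts[I + 1], Elts[I]) != Splat)
      return std::nullopt;
  return Splat;
}

int main() {
  // A v4i32 sequence holding the f64 splat <2.0, 2.0>: each 64-bit lane is
  // lo = 0x00000000, hi = 0x40000000 (2.0 == 0x4000000000000000).
  std::vector<uint32_t> Elts = {0x00000000u, 0x40000000u,
                                0x00000000u, 0x40000000u};
  return matchSplat64(Elts) == std::optional<uint64_t>(0x4000000000000000ull)
             ? 0 : 1;
}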
@@ -1029,17 +1066,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
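
A side note on the 0/-1 special case, with a standalone check that is not from the patch: for a 32-bit immediate X replicated across all lanes, reading two adjacent lanes as a single 64-bit element yields Make_64(X, X), and that only equals the sign-extended 64-bit value of X when X is 0 or -1. The sketch below verifies this with plain integer arithmetic; the candidate list is arbitrary.

// Standalone check: only 0 and -1 read the same as 32-bit and 64-bit splats.
#include <cstdint>

static uint64_t Make_64(uint32_t High, uint32_t Low) {
  return (uint64_t(High) << 32) | Low;
}

int main() {
  int Found = 0;
  // Scan a few representative 32-bit immediates; only 0 and -1 survive.
  const int32_t Candidates[] = {0, -1, 1, -2, 0x40000000, INT32_MIN};
  for (int32_t X : Candidates) {
    uint64_t As64BitSplat = Make_64(uint32_t(X), uint32_t(X));
    uint64_t AsSExtImm = uint64_t(int64_t(X));
    if (As64BitSplat == AsSExtImm)
      ++Found; // Holds only for X == 0 and X == -1.
  }
  return Found == 2 ? 0 : 1;
}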
@@ -1117,7 +1155,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1128,10 +1166,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }
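
Why the caller now tests SplatRC instead of SplatVal: with the splat returned as a plain int64_t, "return {}" from isRegSeqSplat value-initializes the pair to {0, nullptr}, and 0 is itself a legitimate splat immediate, so only the register-class pointer can signal "no splat". A simplified, hypothetical sketch (stand-in types, not LLVM code) of that pattern:

// Illustrative sketch of using the pointer half of the pair as the flag.
#include <cstdint>
#include <utility>

struct RegClass {}; // stand-in for const TargetRegisterClass

static std::pair<int64_t, const RegClass *> noSplat() {
  return {}; // value-initialized to {0, nullptr}
}

static std::pair<int64_t, const RegClass *> zeroSplat(const RegClass &RC) {
  return {0, &RC}; // a real splat whose value happens to be 0
}

int main() {
  RegClass RC;
  auto A = noSplat();
  auto B = zeroSplat(RC);
  // Testing the value would wrongly treat B like A; testing the pointer works.
  bool ANone = (A.second == nullptr);
  bool BIsSplat = (B.second != nullptr);
  return (ANone && BIsSplat) ? 0 : 1;
}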