@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
227227 getRegSeqInit (SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
228228 Register UseReg) const ;
229229
230- std::pair<MachineOperand * , const TargetRegisterClass *>
230+ std::pair<int64_t , const TargetRegisterClass *>
231231 isRegSeqSplat (MachineInstr &RegSeg) const ;
232232
233- MachineOperand * tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
234- MachineOperand * SplatVal,
235- const TargetRegisterClass *SplatRC) const ;
233+ bool tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
234+ int64_t SplatVal,
235+ const TargetRegisterClass *SplatRC) const ;
236236
237237 bool tryToFoldACImm (const FoldableDef &OpToFold, MachineInstr *UseMI,
238238 unsigned UseOpIdx,
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
967967 return getRegSeqInit (*Def, Defs);
968968}
969969
970- std::pair<MachineOperand * , const TargetRegisterClass *>
970+ std::pair<int64_t , const TargetRegisterClass *>
971971SIFoldOperandsImpl::isRegSeqSplat (MachineInstr &RegSeq) const {
972972 SmallVector<std::pair<MachineOperand *, unsigned >, 32 > Defs;
973973 const TargetRegisterClass *SrcRC = getRegSeqInit (RegSeq, Defs);
974974 if (!SrcRC)
975975 return {};
976976
977- // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
978- // every other other element is 0 for 64-bit immediates)
977+ bool TryToMatchSplat64 = false ;
978+
979979 int64_t Imm;
980980 for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
981981 const MachineOperand *Op = Defs[I].first ;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
987987 Imm = SubImm;
988988 continue ;
989989 }
990- if (Imm != SubImm)
990+
991+ if (Imm != SubImm) {
992+ if (I == 1 && (E & 1 ) == 0 ) {
993+ // If we have an even number of inputs, there's a chance this is a
994+ // 64-bit element splat broken into 32-bit pieces.
995+ TryToMatchSplat64 = true ;
996+ break ;
997+ }
998+
991999 return {}; // Can only fold splat constants
1000+ }
1001+ }
1002+
1003+ if (!TryToMatchSplat64)
1004+ return {Defs[0 ].first ->getImm (), SrcRC};
1005+
1006+ // Fallback to recognizing 64-bit splats broken into 32-bit pieces
1007+ // (i.e. recognize every other element is 0 for 64-bit immediates)
1008+ int64_t SplatVal64;
1009+ for (unsigned I = 0 , E = Defs.size (); I != E; I += 2 ) {
1010+ const MachineOperand *Op0 = Defs[I].first ;
1011+ const MachineOperand *Op1 = Defs[I + 1 ].first ;
1012+
1013+ if (!Op0->isImm () || !Op1->isImm ())
1014+ return {};
1015+
1016+ unsigned SubReg0 = Defs[I].second ;
1017+ unsigned SubReg1 = Defs[I + 1 ].second ;
1018+
1019+ // Assume we're going to generally encounter reg_sequences with sorted
1020+ // subreg indexes, so reject any that aren't consecutive.
1021+ if (TRI->getChannelFromSubReg (SubReg0) + 1 !=
1022+ TRI->getChannelFromSubReg (SubReg1))
1023+ return {};
1024+
1025+ int64_t MergedVal = Make_64 (Op1->getImm (), Op0->getImm ());
1026+ if (I == 0 )
1027+ SplatVal64 = MergedVal;
1028+ else if (SplatVal64 != MergedVal)
1029+ return {};
9921030 }
9931031
994- return {Defs[0 ].first , SrcRC};
1032+ const TargetRegisterClass *RC64 = TRI->getSubRegisterClass (
1033+ MRI->getRegClass (RegSeq.getOperand (0 ).getReg ()), AMDGPU::sub0_sub1);
1034+
1035+ return {SplatVal64, RC64};
9951036}
9961037
997- MachineOperand * SIFoldOperandsImpl::tryFoldRegSeqSplat (
998- MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand * SplatVal,
1038+ bool SIFoldOperandsImpl::tryFoldRegSeqSplat (
1039+ MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
9991040 const TargetRegisterClass *SplatRC) const {
10001041 const MCInstrDesc &Desc = UseMI->getDesc ();
10011042 if (UseOpIdx >= Desc.getNumOperands ())
1002- return nullptr ;
1043+ return false ;
10031044
10041045 // Filter out unhandled pseudos.
10051046 if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
1006- return nullptr ;
1047+ return false ;
10071048
10081049 int16_t RCID = Desc.operands ()[UseOpIdx].RegClass ;
10091050 if (RCID == -1 )
1010- return nullptr ;
1051+ return false ;
1052+
1053+ const TargetRegisterClass *OpRC = TRI->getRegClass (RCID);
10111054
10121055 // Special case 0/-1, since when interpreted as a 64-bit element both halves
1013- // have the same bits. Effectively this code does not handle 64-bit element
1014- // operands correctly, as the incoming 64-bit constants are already split into
1015- // 32-bit sequence elements.
1016- //
1017- // TODO: We should try to figure out how to interpret the reg_sequence as a
1018- // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
1019- // constants.
1020- if (SplatVal->getImm () != 0 && SplatVal->getImm () != -1 ) {
1021- const TargetRegisterClass *OpRC = TRI->getRegClass (RCID);
1056+ // have the same bits. These are the only cases where a splat has the same
1057+ // interpretation for 32-bit and 64-bit splats.
1058+ if (SplatVal != 0 && SplatVal != -1 ) {
10221059 // We need to figure out the scalar type read by the operand. e.g. the MFMA
10231060 // operand will be AReg_128, and we want to check if it's compatible with an
10241061 // AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
10321069 OpRC = TRI->getSubRegisterClass (OpRC, AMDGPU::sub0_sub1);
10331070 break ;
10341071 default :
1035- return nullptr ;
1072+ return false ;
10361073 }
10371074
10381075 if (!TRI->getCommonSubClass (OpRC, SplatRC))
1039- return nullptr ;
1076+ return false ;
10401077 }
10411078
1042- if (!TII->isOperandLegal (*UseMI, UseOpIdx, SplatVal))
1043- return nullptr ;
1079+ MachineOperand TmpOp = MachineOperand::CreateImm (SplatVal);
1080+ if (!TII->isOperandLegal (*UseMI, UseOpIdx, &TmpOp))
1081+ return false ;
10441082
1045- return SplatVal ;
1083+ return true ;
10461084}
10471085
10481086bool SIFoldOperandsImpl::tryToFoldACImm (
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
11201158 Register RegSeqDstReg = UseMI->getOperand (0 ).getReg ();
11211159 unsigned RegSeqDstSubReg = UseMI->getOperand (UseOpIdx + 1 ).getImm ();
11221160
1123- MachineOperand * SplatVal;
1161+ int64_t SplatVal;
11241162 const TargetRegisterClass *SplatRC;
11251163 std::tie (SplatVal, SplatRC) = isRegSeqSplat (*UseMI);
11261164
@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
11311169 MachineInstr *RSUseMI = RSUse->getParent ();
11321170 unsigned OpNo = RSUseMI->getOperandNo (RSUse);
11331171
1134- if (SplatVal) {
1135- if (MachineOperand *Foldable =
1136- tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1137- FoldableDef SplatDef (*Foldable, SplatRC);
1172+ if (SplatRC) {
1173+ if (tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1174+ FoldableDef SplatDef (SplatVal, SplatRC);
11381175 appendFoldCandidate (FoldList, RSUseMI, OpNo, SplatDef);
11391176 continue ;
11401177 }
0 commit comments