@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
227227 getRegSeqInit (SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
228228 Register UseReg) const ;
229229
230- std::pair<MachineOperand * , const TargetRegisterClass *>
230+ std::pair<int64_t , const TargetRegisterClass *>
231231 isRegSeqSplat (MachineInstr &RegSeg) const ;
232232
233- MachineOperand * tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
234- MachineOperand * SplatVal,
235- const TargetRegisterClass *SplatRC) const ;
233+ bool tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
234+ int64_t SplatVal,
235+ const TargetRegisterClass *SplatRC) const ;
236236
237237 bool tryToFoldACImm (const FoldableDef &OpToFold, MachineInstr *UseMI,
238238 unsigned UseOpIdx,
@@ -966,15 +966,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
966966 return getRegSeqInit (*Def, Defs);
967967}
968968
969- std::pair<MachineOperand * , const TargetRegisterClass *>
969+ std::pair<int64_t , const TargetRegisterClass *>
970970SIFoldOperandsImpl::isRegSeqSplat (MachineInstr &RegSeq) const {
971971 SmallVector<std::pair<MachineOperand *, unsigned >, 32 > Defs;
972972 const TargetRegisterClass *SrcRC = getRegSeqInit (RegSeq, Defs);
973973 if (!SrcRC)
974974 return {};
975975
976- // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
977- // every other other element is 0 for 64-bit immediates)
976+ bool TryToMatchSplat64 = false ;
977+
978978 int64_t Imm;
979979 for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
980980 const MachineOperand *Op = Defs[I].first ;
@@ -986,38 +986,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
986986 Imm = SubImm;
987987 continue ;
988988 }
989- if (Imm != SubImm)
989+
990+ if (Imm != SubImm) {
991+ if (I == 1 && (E & 1 ) == 0 ) {
992+ // If we have an even number of inputs, there's a chance this is a
993+ // 64-bit element splat broken into 32-bit pieces.
994+ TryToMatchSplat64 = true ;
995+ break ;
996+ }
997+
990998 return {}; // Can only fold splat constants
999+ }
1000+ }
1001+
1002+ if (!TryToMatchSplat64)
1003+ return {Defs[0 ].first ->getImm (), SrcRC};
1004+
1005+ // Fallback to recognizing 64-bit splats broken into 32-bit pieces
1006+ // (i.e. recognize every other element is 0 for 64-bit immediates)
1007+ int64_t SplatVal64;
1008+ for (unsigned I = 0 , E = Defs.size (); I != E; I += 2 ) {
1009+ const MachineOperand *Op0 = Defs[I].first ;
1010+ const MachineOperand *Op1 = Defs[I + 1 ].first ;
1011+
1012+ if (!Op0->isImm () || !Op1->isImm ())
1013+ return {};
1014+
1015+ unsigned SubReg0 = Defs[I].second ;
1016+ unsigned SubReg1 = Defs[I + 1 ].second ;
1017+
1018+ // Assume we're going to generally encounter reg_sequences with sorted
1019+ // subreg indexes, so reject any that aren't consecutive.
1020+ if (TRI->getChannelFromSubReg (SubReg0) + 1 !=
1021+ TRI->getChannelFromSubReg (SubReg1))
1022+ return {};
1023+
1024+ int64_t MergedVal = Make_64 (Op1->getImm (), Op0->getImm ());
1025+ if (I == 0 )
1026+ SplatVal64 = MergedVal;
1027+ else if (SplatVal64 != MergedVal)
1028+ return {};
9911029 }
9921030
993- return {Defs[0 ].first , SrcRC};
1031+ const TargetRegisterClass *RC64 = TRI->getSubRegisterClass (
1032+ MRI->getRegClass (RegSeq.getOperand (0 ).getReg ()), AMDGPU::sub0_sub1);
1033+
1034+ return {SplatVal64, RC64};
9941035}
9951036
996- MachineOperand * SIFoldOperandsImpl::tryFoldRegSeqSplat (
997- MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand * SplatVal,
1037+ bool SIFoldOperandsImpl::tryFoldRegSeqSplat (
1038+ MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
9981039 const TargetRegisterClass *SplatRC) const {
9991040 const MCInstrDesc &Desc = UseMI->getDesc ();
10001041 if (UseOpIdx >= Desc.getNumOperands ())
1001- return nullptr ;
1042+ return false ;
10021043
10031044 // Filter out unhandled pseudos.
10041045 if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
1005- return nullptr ;
1046+ return false ;
10061047
10071048 int16_t RCID = Desc.operands ()[UseOpIdx].RegClass ;
10081049 if (RCID == -1 )
1009- return nullptr ;
1050+ return false ;
1051+
1052+ const TargetRegisterClass *OpRC = TRI->getRegClass (RCID);
10101053
10111054 // Special case 0/-1, since when interpreted as a 64-bit element both halves
1012- // have the same bits. Effectively this code does not handle 64-bit element
1013- // operands correctly, as the incoming 64-bit constants are already split into
1014- // 32-bit sequence elements.
1015- //
1016- // TODO: We should try to figure out how to interpret the reg_sequence as a
1017- // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
1018- // constants.
1019- if (SplatVal->getImm () != 0 && SplatVal->getImm () != -1 ) {
1020- const TargetRegisterClass *OpRC = TRI->getRegClass (RCID);
1055+ // have the same bits. These are the only cases where a splat has the same
1056+ // interpretation for 32-bit and 64-bit splats.
1057+ if (SplatVal != 0 && SplatVal != -1 ) {
10211058 // We need to figure out the scalar type read by the operand. e.g. the MFMA
10221059 // operand will be AReg_128, and we want to check if it's compatible with an
10231060 // AReg_32 constant.
@@ -1031,17 +1068,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
10311068 OpRC = TRI->getSubRegisterClass (OpRC, AMDGPU::sub0_sub1);
10321069 break ;
10331070 default :
1034- return nullptr ;
1071+ return false ;
10351072 }
10361073
10371074 if (!TRI->getCommonSubClass (OpRC, SplatRC))
1038- return nullptr ;
1075+ return false ;
10391076 }
10401077
1041- if (!TII->isOperandLegal (*UseMI, UseOpIdx, SplatVal))
1042- return nullptr ;
1078+ MachineOperand TmpOp = MachineOperand::CreateImm (SplatVal);
1079+ if (!TII->isOperandLegal (*UseMI, UseOpIdx, &TmpOp))
1080+ return false ;
10431081
1044- return SplatVal ;
1082+ return true ;
10451083}
10461084
10471085bool SIFoldOperandsImpl::tryToFoldACImm (
@@ -1119,7 +1157,7 @@ void SIFoldOperandsImpl::foldOperand(
11191157 Register RegSeqDstReg = UseMI->getOperand (0 ).getReg ();
11201158 unsigned RegSeqDstSubReg = UseMI->getOperand (UseOpIdx + 1 ).getImm ();
11211159
1122- MachineOperand * SplatVal;
1160+ int64_t SplatVal;
11231161 const TargetRegisterClass *SplatRC;
11241162 std::tie (SplatVal, SplatRC) = isRegSeqSplat (*UseMI);
11251163
@@ -1130,10 +1168,9 @@ void SIFoldOperandsImpl::foldOperand(
11301168 MachineInstr *RSUseMI = RSUse->getParent ();
11311169 unsigned OpNo = RSUseMI->getOperandNo (RSUse);
11321170
1133- if (SplatVal) {
1134- if (MachineOperand *Foldable =
1135- tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1136- FoldableDef SplatDef (*Foldable, SplatRC);
1171+ if (SplatRC) {
1172+ if (tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1173+ FoldableDef SplatDef (SplatVal, SplatRC);
11371174 appendFoldCandidate (FoldList, RSUseMI, OpNo, SplatDef);
11381175 continue ;
11391176 }
0 commit comments