@@ -64,6 +64,7 @@ def Write8PassMAI : SchedWrite;
6464def Write16PassMAI : SchedWrite;
6565def Write4PassDGEMM : SchedWrite;
6666def Write8PassDGEMM : SchedWrite;
67+ def Write16PassDGEMM : SchedWrite;
6768
6869// Scalar float instructions
6970def WriteSFPU : SchedWrite;
@@ -94,6 +95,7 @@ def SIFullSpeedModel : SISchedMachineModel;
9495def SIQuarterSpeedModel : SISchedMachineModel;
9596def SIDPFullSpeedModel : SISchedMachineModel;
9697def SIDPGFX940FullSpeedModel : SISchedMachineModel;
98+ def SIDPGFX950FullSpeedModel : SISchedMachineModel;
9799def GFX10SpeedModel : SISchedMachineModel;
98100def GFX11SpeedModel : SISchedMachineModel;
99101def GFX12SpeedModel : SISchedMachineModel;
@@ -169,6 +171,8 @@ multiclass SICommonWriteRes {
169171 def : HWVALUWriteRes<Write4PassDGEMM, 4>;
170172 let ReleaseAtCycles = [8] in
171173 def : HWVALUWriteRes<Write8PassDGEMM, 8>;
174+ let ReleaseAtCycles = [16] in
175+ def : HWVALUWriteRes<Write16PassDGEMM, 16>;
172176
173177 let ReleaseAtCycles = [2] in
174178 def : HWWriteRes<Write2PassMAI, [HWXDL], 2>;
@@ -201,6 +205,13 @@ def WriteCopy : SchedWriteVariant<[
201205 SchedVar<PredIsVGPR64Copy, [Write64Bit]>,
202206 SchedVar<NoSchedPred, [WriteSALU]>]>;
203207
208+ // Check if any matrix inputs are interpreted as f8 in an f8f6f4 mfma
209+ // instruction.
210+ def PredIsF8_MFMA_SCALE : SchedPredicate<[{
211+ TII->getNamedOperand(*MI, AMDGPU::OpName::cbsz)->getImm() <= AMDGPU::MFMAScaleFormats::FP8_E5M2 ||
212+ TII->getNamedOperand(*MI, AMDGPU::OpName::blgp)->getImm() <= AMDGPU::MFMAScaleFormats::FP8_E5M2
213+ }]>;
214+
204215let SchedModel = SIFullSpeedModel in {
205216
206217defm : SICommonWriteRes;
@@ -299,6 +310,58 @@ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_32X32X")>;
299310
300311} // End SchedModel = SIDPGFX940FullSpeedModel
301312
313+
314+ let SchedModel = SIDPGFX950FullSpeedModel in {
315+ defm : SICommonWriteRes;
316+
317+ def : HWVALUWriteRes<WriteFloatFMA, 1>;
318+ def : HWVALUWriteRes<WriteDouble, 1>;
319+ def : HWVALUWriteRes<WriteDoubleAdd, 1>;
320+ def : HWVALUWriteRes<WriteDoubleCvt, 1>;
321+ def : HWVALUWriteRes<WriteTrans64, 4>;
322+ def : HWVALUWriteRes<WriteIntMul, 1>;
323+ def : HWVALUWriteRes<Write64Bit, 1>;
324+
325+ def : InstRW<[WriteCopy], (instrs COPY)>;
326+ def : InstRW<[Write64Bit], (instregex "^V_ACCVGPR_WRITE_B32_e64$")>;
327+ def : InstRW<[Write2PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_4X4X")>;
328+
329+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X8X")>;
330+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X16")>;
331+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X32")>;
332+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X64")>;
333+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X[14][FBI]")>;
334+
335+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X4XF")>;
336+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X8")>;
337+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X16")>;
338+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X32_")>;
339+ def : InstRW<[Write16PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X[124][FBI]")>;
340+
341+ def : InstRW<[Write4PassDGEMM, MIMFMARead], (instregex "^V_MFMA_.64_4X4X")>;
342+ def : InstRW<[Write16PassDGEMM, MIMFMARead], (instregex "^V_MFMA_.64_16X16X")>;
343+
344+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_16X16X")>;
345+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_32X32X")>;
346+
347+
348+ // If either matrix format is f8, the instruction takes 2x as many
349+ // cycles. TODO: This isn't reflected in MCA.
350+ def WriteMFMAScale_16X16X128_F8F6F4 : SchedWriteVariant<[
351+ SchedVar<PredIsF8_MFMA_SCALE, [Write8PassMAI]>,
352+ SchedVar<NoSchedPred, [Write4PassMAI]>]>;
353+ def WriteMFMAScale_32X32X64_F8F6F4 : SchedWriteVariant<[
354+ SchedVar<PredIsF8_MFMA_SCALE, [Write16PassMAI]>,
355+ SchedVar<NoSchedPred, [Write8PassMAI]>]>;
356+
357+ def : InstRW<[WriteMFMAScale_16X16X128_F8F6F4, MIMFMARead],
358+ (instregex "^V_MFMA(_SCALE)?_.32_16X16X128_F8F6F4")>;
359+ def : InstRW<[WriteMFMAScale_32X32X64_F8F6F4, MIMFMARead],
360+ (instregex "^V_MFMA(_SCALE)?_.32_32X32X64_F8F6F4")>;
361+
362+ } // End SchedModel = SIDPGFX950FullSpeedModel
363+
364+
302365let SchedModel = GFX10SpeedModel in {
303366
304367// The latency values are 1 / (operations / cycle).
0 commit comments