Skip to content

Commit 8fc2a53

Browse files
committed
Reland "Revert "AMDGPU: Treat WMMA XDL ops as TRANS in S_DELAY_ALU insertion for gfx1250 (llvm#149208)""
This reverts commit 417cd79.
1 parent 49ed5b7 commit 8fc2a53

File tree

6 files changed

+48
-14
lines changed

6 files changed

+48
-14
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace {
2525

2626
class AMDGPUInsertDelayAlu {
2727
public:
28+
const GCNSubtarget *ST;
2829
const SIInstrInfo *SII;
2930
const TargetRegisterInfo *TRI;
3031

@@ -65,13 +66,16 @@ class AMDGPUInsertDelayAlu {
6566
// Types of delay that can be encoded in an s_delay_alu instruction.
6667
enum DelayType { VALU, TRANS, SALU, OTHER };
6768

68-
// Get the delay type for an instruction with the specified TSFlags.
69-
static DelayType getDelayType(uint64_t TSFlags) {
70-
if (TSFlags & SIInstrFlags::TRANS)
69+
// Get the delay type for a MachineInstr.
70+
DelayType getDelayType(const MachineInstr &MI) {
71+
if (SIInstrInfo::isTRANS(MI))
7172
return TRANS;
72-
if (TSFlags & SIInstrFlags::VALU)
73+
// WMMA XDL ops are treated the same as TRANS.
74+
if (AMDGPU::isGFX1250(*ST) && SII->isXDLWMMA(MI))
75+
return TRANS;
76+
if (SIInstrInfo::isVALU(MI))
7377
return VALU;
74-
if (TSFlags & SIInstrFlags::SALU)
78+
if (SIInstrInfo::isSALU(MI))
7579
return SALU;
7680
return OTHER;
7781
}
@@ -368,7 +372,7 @@ class AMDGPUInsertDelayAlu {
368372
continue;
369373
}
370374

371-
DelayType Type = getDelayType(MI.getDesc().TSFlags);
375+
DelayType Type = getDelayType(MI);
372376

373377
if (instructionWaitsForSGPRWrites(MI)) {
374378
auto It = State.find(LastSGPRFromVALU);
@@ -456,12 +460,12 @@ class AMDGPUInsertDelayAlu {
456460
LLVM_DEBUG(dbgs() << "AMDGPUInsertDelayAlu running on " << MF.getName()
457461
<< "\n");
458462

459-
const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
460-
if (!ST.hasDelayAlu())
463+
ST = &MF.getSubtarget<GCNSubtarget>();
464+
if (!ST->hasDelayAlu())
461465
return false;
462466

463-
SII = ST.getInstrInfo();
464-
TRI = ST.getRegisterInfo();
467+
SII = ST->getInstrInfo();
468+
TRI = ST->getRegisterInfo();
465469
SchedModel = &SII->getSchedModel();
466470

467471
// Calculate the delay state for each basic block, iterating until we reach

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10602,10 +10602,23 @@ bool SIInstrInfo::isGlobalMemoryObject(const MachineInstr *MI) const {
1060210602
return TargetInstrInfo::isGlobalMemoryObject(MI);
1060310603
}
1060410604

10605+
bool SIInstrInfo::isXDLWMMA(const MachineInstr &MI) const {
10606+
if (!isWMMA(MI) && !isSWMMAC(MI))
10607+
return false;
10608+
10609+
if (AMDGPU::isGFX1250(ST))
10610+
return AMDGPU::getWMMAIsXDL(MI.getOpcode());
10611+
10612+
return true;
10613+
}
10614+
1060510615
bool SIInstrInfo::isXDL(const MachineInstr &MI) const {
1060610616
unsigned Opcode = MI.getOpcode();
1060710617

10608-
if (!SIInstrInfo::isMAI(MI) || isDGEMM(Opcode) ||
10618+
if (AMDGPU::isGFX12Plus(ST))
10619+
return isDOT(MI) || isXDLWMMA(MI);
10620+
10621+
if (!isMAI(MI) || isDGEMM(Opcode) ||
1060910622
Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_e64 ||
1061010623
Opcode == AMDGPU::V_ACCVGPR_READ_B32_e64)
1061110624
return false;

llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,8 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
899899
return get(Opcode).TSFlags & SIInstrFlags::IsDOT;
900900
}
901901

902+
bool isXDLWMMA(const MachineInstr &MI) const;
903+
902904
bool isXDL(const MachineInstr &MI) const;
903905

904906
static bool isDGEMM(unsigned Opcode) { return AMDGPU::getMAIIsDGEMM(Opcode); }

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ unsigned getCompletionActionImplicitArgPosition(unsigned CodeObjectVersion) {
296296
#define GET_MIMGOffsetMappingTable_IMPL
297297
#define GET_MIMGG16MappingTable_IMPL
298298
#define GET_MAIInstInfoTable_IMPL
299+
#define GET_WMMAInstInfoTable_IMPL
299300
#include "AMDGPUGenSearchableTables.inc"
300301

301302
int getMIMGOpcode(unsigned BaseOpcode, unsigned MIMGEncoding,
@@ -568,6 +569,11 @@ bool getMAIIsGFX940XDL(unsigned Opc) {
568569
return Info && Info->is_gfx940_xdl;
569570
}
570571

572+
bool getWMMAIsXDL(unsigned Opc) {
573+
const WMMAInstInfo *Info = getWMMAInstInfoHelper(Opc);
574+
return Info ? Info->is_wmma_xdl : false;
575+
}
576+
571577
uint8_t mfmaScaleF8F6F4FormatToNumRegs(unsigned EncodingVal) {
572578
switch (EncodingVal) {
573579
case MFMAScaleFormats::FP6_E2M3:

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ struct True16D16Info {
119119
unsigned LoOp;
120120
};
121121

122+
struct WMMAInstInfo {
123+
uint16_t Opcode;
124+
bool is_wmma_xdl;
125+
};
126+
122127
#define GET_MIMGBaseOpcode_DECL
123128
#define GET_MIMGDim_DECL
124129
#define GET_MIMGEncoding_DECL
@@ -129,6 +134,7 @@ struct True16D16Info {
129134
#define GET_isMFMA_F8F6F4Table_DECL
130135
#define GET_isCvtScaleF32_F32F16ToF8F4Table_DECL
131136
#define GET_True16D16Table_DECL
137+
#define GET_WMMAInstInfoTable_DECL
132138
#include "AMDGPUGenSearchableTables.inc"
133139

134140
namespace IsaInfo {
@@ -593,6 +599,9 @@ bool getMAIIsDGEMM(unsigned Opc);
593599
LLVM_READONLY
594600
bool getMAIIsGFX940XDL(unsigned Opc);
595601

602+
LLVM_READONLY
603+
bool getWMMAIsXDL(unsigned Opc);
604+
596605
// Get an equivalent BitOp3 for a binary logical \p Opc.
597606
// \returns BitOp3 modifier for the logical operation or zero.
598607
// Used in VOPD3 conversion.

llvm/test/CodeGen/AMDGPU/insert-delay-alu-wmma-xdl.mir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ body: |
99
; CHECK: %bb.0:
1010
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[8:15]
1111
; CHECK-NEXT: v_exp_f32_e32 v16, v16
12-
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1)
12+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
1313
; CHECK-NEXT: v_add_nc_u32_e32 v17, v17, v8
1414
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16, $vgpr17
1515
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, 0, 0, 0, 0, implicit $exec
@@ -26,7 +26,7 @@ body: |
2626
; CHECK: %bb.0:
2727
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[16:23]
2828
; CHECK-NEXT: v_exp_f32_e32 v24, v24
29-
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1)
29+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
3030
; CHECK-NEXT: v_add_nc_u32_e32 v25, v25, v8
3131
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24, $vgpr25
3232
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_threeaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, 0, 0, 0, 0, implicit $exec
@@ -42,7 +42,7 @@ body: |
4242
; CHECK: %bb.0:
4343
; CHECK-NEXT: v_swmmac_f16_16x16x128_bf8_bf8 v[24:27], v[0:7], v[8:23], v[28:29]
4444
; CHECK-NEXT: v_exp_f32_e32 v30, v30
45-
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1)
45+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
4646
; CHECK-NEXT: v_add_nc_u32_e32 v31, v31, v24
4747
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28, $vgpr29, $vgpr30, $vgpr31
4848
$vgpr24_vgpr25_vgpr26_vgpr27 = V_SWMMAC_F16_16X16X128_BF8_BF8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28_vgpr29, 0, 0, 0, implicit $exec

0 commit comments

Comments
 (0)