Skip to content

Commit b52cf75

Browse files
changpeng and sstipano
authored
AMDGPU: Treat WMMA XDL ops as TRANS in S_DELAY_ALU insertion for gfx1250 (#149208)
WMMA XDL instructions are tracked as TRANS ops, so the compiler should treat them the same as TRANS ops during S_DELAY_ALU insertion. We use a searchable table so that the InsertDelayAlu pass can recognize these WMMA XDL instructions. Co-authored-by: Stefan Stipanovic <[email protected]>
1 parent 9d78eb5 commit b52cf75

File tree

6 files changed

+129
-11
lines changed

6 files changed

+129
-11
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace {
2525

2626
class AMDGPUInsertDelayAlu {
2727
public:
28+
const GCNSubtarget *ST;
2829
const SIInstrInfo *SII;
2930
const TargetRegisterInfo *TRI;
3031

@@ -65,13 +66,16 @@ class AMDGPUInsertDelayAlu {
6566
// Types of delay that can be encoded in an s_delay_alu instruction.
6667
enum DelayType { VALU, TRANS, SALU, OTHER };
6768

68-
// Get the delay type for an instruction with the specified TSFlags.
69-
static DelayType getDelayType(uint64_t TSFlags) {
70-
if (TSFlags & SIInstrFlags::TRANS)
69+
// Get the delay type for a MachineInstr.
70+
DelayType getDelayType(const MachineInstr &MI) {
71+
if (SIInstrInfo::isTRANS(MI))
7172
return TRANS;
72-
if (TSFlags & SIInstrFlags::VALU)
73+
// WMMA XDL ops are treated the same as TRANS.
74+
if (AMDGPU::isGFX1250(*ST) && SII->isXDLWMMA(MI))
75+
return TRANS;
76+
if (SIInstrInfo::isVALU(MI))
7377
return VALU;
74-
if (TSFlags & SIInstrFlags::SALU)
78+
if (SIInstrInfo::isSALU(MI))
7579
return SALU;
7680
return OTHER;
7781
}
@@ -368,7 +372,7 @@ class AMDGPUInsertDelayAlu {
368372
continue;
369373
}
370374

371-
DelayType Type = getDelayType(MI.getDesc().TSFlags);
375+
DelayType Type = getDelayType(MI);
372376

373377
if (instructionWaitsForSGPRWrites(MI)) {
374378
auto It = State.find(LastSGPRFromVALU);
@@ -456,12 +460,12 @@ class AMDGPUInsertDelayAlu {
456460
LLVM_DEBUG(dbgs() << "AMDGPUInsertDelayAlu running on " << MF.getName()
457461
<< "\n");
458462

459-
const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
460-
if (!ST.hasDelayAlu())
463+
ST = &MF.getSubtarget<GCNSubtarget>();
464+
if (!ST->hasDelayAlu())
461465
return false;
462466

463-
SII = ST.getInstrInfo();
464-
TRI = ST.getRegisterInfo();
467+
SII = ST->getInstrInfo();
468+
TRI = ST->getRegisterInfo();
465469
SchedModel = &SII->getSchedModel();
466470

467471
// Calculate the delay state for each basic block, iterating until we reach

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10466,10 +10466,23 @@ bool SIInstrInfo::isGlobalMemoryObject(const MachineInstr *MI) const {
1046610466
return TargetInstrInfo::isGlobalMemoryObject(MI);
1046710467
}
1046810468

10469+
bool SIInstrInfo::isXDLWMMA(const MachineInstr &MI) const {
10470+
if (!isWMMA(MI) && !isSWMMAC(MI))
10471+
return false;
10472+
10473+
if (AMDGPU::isGFX1250(ST))
10474+
return AMDGPU::getWMMAIsXDL(MI.getOpcode());
10475+
10476+
return true;
10477+
}
10478+
1046910479
bool SIInstrInfo::isXDL(const MachineInstr &MI) const {
1047010480
unsigned Opcode = MI.getOpcode();
1047110481

10472-
if (!SIInstrInfo::isMAI(MI) || isDGEMM(Opcode) ||
10482+
if (AMDGPU::isGFX12Plus(ST))
10483+
return isDOT(MI) || isXDLWMMA(MI);
10484+
10485+
if (!isMAI(MI) || isDGEMM(Opcode) ||
1047310486
Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_e64 ||
1047410487
Opcode == AMDGPU::V_ACCVGPR_READ_B32_e64)
1047510488
return false;

llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,8 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
867867
return get(Opcode).TSFlags & SIInstrFlags::IsDOT;
868868
}
869869

870+
bool isXDLWMMA(const MachineInstr &MI) const;
871+
870872
bool isXDL(const MachineInstr &MI) const;
871873

872874
static bool isDGEMM(unsigned Opcode) { return AMDGPU::getMAIIsDGEMM(Opcode); }

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ unsigned getCompletionActionImplicitArgPosition(unsigned CodeObjectVersion) {
296296
#define GET_MIMGOffsetMappingTable_IMPL
297297
#define GET_MIMGG16MappingTable_IMPL
298298
#define GET_MAIInstInfoTable_IMPL
299+
#define GET_WMMAInstInfoTable_IMPL
299300
#include "AMDGPUGenSearchableTables.inc"
300301

301302
int getMIMGOpcode(unsigned BaseOpcode, unsigned MIMGEncoding,
@@ -568,6 +569,11 @@ bool getMAIIsGFX940XDL(unsigned Opc) {
568569
return Info && Info->is_gfx940_xdl;
569570
}
570571

572+
bool getWMMAIsXDL(unsigned Opc) {
573+
const WMMAInstInfo *Info = getWMMAInstInfoHelper(Opc);
574+
return Info ? Info->is_wmma_xdl : false;
575+
}
576+
571577
uint8_t mfmaScaleF8F6F4FormatToNumRegs(unsigned EncodingVal) {
572578
switch (EncodingVal) {
573579
case MFMAScaleFormats::FP6_E2M3:

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ struct True16D16Info {
119119
unsigned LoOp;
120120
};
121121

122+
struct WMMAInstInfo {
123+
uint16_t Opcode;
124+
bool is_wmma_xdl;
125+
};
126+
122127
#define GET_MIMGBaseOpcode_DECL
123128
#define GET_MIMGDim_DECL
124129
#define GET_MIMGEncoding_DECL
@@ -129,6 +134,7 @@ struct True16D16Info {
129134
#define GET_isMFMA_F8F6F4Table_DECL
130135
#define GET_isCvtScaleF32_F32F16ToF8F4Table_DECL
131136
#define GET_True16D16Table_DECL
137+
#define GET_WMMAInstInfoTable_DECL
132138
#include "AMDGPUGenSearchableTables.inc"
133139

134140
namespace IsaInfo {
@@ -593,6 +599,9 @@ bool getMAIIsDGEMM(unsigned Opc);
593599
LLVM_READONLY
594600
bool getMAIIsGFX940XDL(unsigned Opc);
595601

602+
LLVM_READONLY
603+
bool getWMMAIsXDL(unsigned Opc);
604+
596605
// Get an equivalent BitOp3 for a binary logical \p Opc.
597606
// \returns BitOp3 modifier for the logical operation or zero.
598607
// Used in VOPD3 conversion.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -start-before=amdgpu-insert-delay-alu %s -o - | FileCheck %s
2+
3+
---
4+
name: wmma_xdl_twoaddr_trans
5+
tracksRegLiveness: true
6+
body: |
7+
bb.0:
8+
; CHECK-LABEL: {{^}}wmma_xdl_twoaddr_trans:
9+
; CHECK: %bb.0:
10+
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[8:15]
11+
; CHECK-NEXT: v_exp_f32_e32 v16, v16
12+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
13+
; CHECK-NEXT: v_add_nc_u32_e32 v17, v17, v8
14+
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16, $vgpr17
15+
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, 0, 0, 0, 0, implicit $exec
16+
$vgpr16 = V_EXP_F32_e32 $vgpr16, implicit $exec, implicit $mode
17+
$vgpr17 = V_ADD_U32_e32 $vgpr17, $vgpr8, implicit $exec
18+
...
19+
20+
---
21+
name: wmma_xdl_threeaddr_trans
22+
tracksRegLiveness: true
23+
body: |
24+
bb.0:
25+
; CHECK-LABEL: {{^}}wmma_xdl_threeaddr_trans:
26+
; CHECK: %bb.0:
27+
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[16:23]
28+
; CHECK-NEXT: v_exp_f32_e32 v24, v24
29+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
30+
; CHECK-NEXT: v_add_nc_u32_e32 v25, v25, v8
31+
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24, $vgpr25
32+
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_threeaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, 0, 0, 0, 0, implicit $exec
33+
$vgpr24 = V_EXP_F32_e32 $vgpr24, implicit $exec, implicit $mode
34+
$vgpr25 = V_ADD_U32_e32 $vgpr25, $vgpr8, implicit $exec
35+
...
36+
37+
name: swmmac_xdl_twoaddr_trans
38+
tracksRegLiveness: true
39+
body: |
40+
bb.0:
41+
; CHECK-LABEL: {{^}}swmmac_xdl_twoaddr_trans:
42+
; CHECK: %bb.0:
43+
; CHECK-NEXT: v_swmmac_f16_16x16x128_bf8_bf8 v[24:27], v[0:7], v[8:23], v[28:29]
44+
; CHECK-NEXT: v_exp_f32_e32 v30, v30
45+
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
46+
; CHECK-NEXT: v_add_nc_u32_e32 v31, v31, v24
47+
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28, $vgpr29, $vgpr30, $vgpr31
48+
$vgpr24_vgpr25_vgpr26_vgpr27 = V_SWMMAC_F16_16X16X128_BF8_BF8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28_vgpr29, 0, 0, 0, implicit $exec
49+
$vgpr30 = V_EXP_F32_e32 $vgpr30, implicit $exec, implicit $mode
50+
$vgpr31 = V_ADD_U32_e32 $vgpr31, $vgpr24, implicit $exec
51+
...
52+
53+
name: wmma_non_xdl_large_data_valu
54+
tracksRegLiveness: true
55+
body: |
56+
bb.0:
57+
; CHECK-LABEL: {{^}}wmma_non_xdl_large_data_valu:
58+
; CHECK: %bb.0:
59+
; CHECK-NEXT: v_wmma_f32_16x16x4_f32 v[4:11], v[0:1], v[2:3], v[4:11] matrix_b_reuse
60+
; CHECK-NEXT: v_exp_f32_e32 v12, v12
61+
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1)
62+
; CHECK-NEXT: v_add_nc_u32_e32 v13, v13, v8
63+
liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4_vgpr5_vgpr6_vgpr7, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, $vgpr12, $vgpr13
64+
$vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11 = V_WMMA_F32_16X16X4_F32_w32_twoaddr 8, $vgpr0_vgpr1, 8, $vgpr2_vgpr3, 8, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, 0, -1, 0, 0, implicit $exec
65+
$vgpr12 = V_EXP_F32_e32 $vgpr12, implicit $exec, implicit $mode
66+
$vgpr13 = V_ADD_U32_e32 $vgpr13, $vgpr8, implicit $exec
67+
...
68+
69+
---
70+
name: dot_xdl_dep_2
71+
tracksRegLiveness: true
72+
body: |
73+
bb.0:
74+
; CHECK-LABEL: {{^}}dot_xdl_dep_2:
75+
; CHECK: %bb.0:
76+
; CHECK-NEXT: v_dot4_i32_iu8 v0, s2, s3, v0 neg_lo:[1,1,0]
77+
; CHECK-NEXT: v_dot4_i32_iu8 v1, s2, s3, v2 neg_lo:[1,1,0]
78+
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_2)
79+
; CHECK-NEXT: v_add_nc_u32_e32 v2, v0, v0
80+
liveins: $vgpr0, $sgpr2, $sgpr3, $vgpr0, $vgpr1, $vgpr2
81+
$vgpr0 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr0, 0, 0, 0, implicit $exec
82+
$vgpr1 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr2, 0, 0, 0, implicit $exec
83+
$vgpr2 = V_ADD_U32_e32 $vgpr0, $vgpr0, implicit $exec
84+
...

0 commit comments

Comments
 (0)