Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace {

class AMDGPUInsertDelayAlu {
public:
const GCNSubtarget *ST;
const SIInstrInfo *SII;
const TargetRegisterInfo *TRI;

Expand Down Expand Up @@ -65,13 +66,16 @@ class AMDGPUInsertDelayAlu {
// Types of delay that can be encoded in an s_delay_alu instruction.
enum DelayType { VALU, TRANS, SALU, OTHER };

// Get the delay type for an instruction with the specified TSFlags.
static DelayType getDelayType(uint64_t TSFlags) {
if (TSFlags & SIInstrFlags::TRANS)
// Get the delay type for a MachineInstr.
DelayType getDelayType(const MachineInstr &MI) {
if (SIInstrInfo::isTRANS(MI))
return TRANS;
if (TSFlags & SIInstrFlags::VALU)
// WMMA XDL ops are treated the same as TRANS.
if (AMDGPU::isGFX1250(*ST) && SII->isXDLWMMA(MI))
return TRANS;
if (SIInstrInfo::isVALU(MI))
return VALU;
if (TSFlags & SIInstrFlags::SALU)
if (SIInstrInfo::isSALU(MI))
return SALU;
return OTHER;
}
Expand Down Expand Up @@ -368,7 +372,7 @@ class AMDGPUInsertDelayAlu {
continue;
}

DelayType Type = getDelayType(MI.getDesc().TSFlags);
DelayType Type = getDelayType(MI);

if (instructionWaitsForSGPRWrites(MI)) {
auto It = State.find(LastSGPRFromVALU);
Expand Down Expand Up @@ -456,12 +460,12 @@ class AMDGPUInsertDelayAlu {
LLVM_DEBUG(dbgs() << "AMDGPUInsertDelayAlu running on " << MF.getName()
<< "\n");

const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
if (!ST.hasDelayAlu())
ST = &MF.getSubtarget<GCNSubtarget>();
if (!ST->hasDelayAlu())
return false;

SII = ST.getInstrInfo();
TRI = ST.getRegisterInfo();
SII = ST->getInstrInfo();
TRI = ST->getRegisterInfo();
SchedModel = &SII->getSchedModel();

// Calculate the delay state for each basic block, iterating until we reach
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10466,10 +10466,23 @@ bool SIInstrInfo::isGlobalMemoryObject(const MachineInstr *MI) const {
return TargetInstrInfo::isGlobalMemoryObject(MI);
}

// Returns true if \p MI is a WMMA or SWMMAC instruction that is classified as
// XDL. On GFX1250 the classification is looked up per opcode through
// AMDGPU::getWMMAIsXDL; on every other subtarget all WMMA/SWMMAC instructions
// are reported as XDL.
bool SIInstrInfo::isXDLWMMA(const MachineInstr &MI) const {
  const bool IsMatrixOp = isWMMA(MI) || isSWMMAC(MI);
  if (!IsMatrixOp)
    return false;

  // Only GFX1250 distinguishes XDL from non-XDL WMMA opcodes.
  return !AMDGPU::isGFX1250(ST) || AMDGPU::getWMMAIsXDL(MI.getOpcode());
}

bool SIInstrInfo::isXDL(const MachineInstr &MI) const {
unsigned Opcode = MI.getOpcode();

if (!SIInstrInfo::isMAI(MI) || isDGEMM(Opcode) ||
if (AMDGPU::isGFX12Plus(ST))
return isDOT(MI) || isXDLWMMA(MI);

if (!isMAI(MI) || isDGEMM(Opcode) ||
Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_e64 ||
Opcode == AMDGPU::V_ACCVGPR_READ_B32_e64)
return false;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,8 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
return get(Opcode).TSFlags & SIInstrFlags::IsDOT;
}

bool isXDLWMMA(const MachineInstr &MI) const;

bool isXDL(const MachineInstr &MI) const;

static bool isDGEMM(unsigned Opcode) { return AMDGPU::getMAIIsDGEMM(Opcode); }
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ unsigned getCompletionActionImplicitArgPosition(unsigned CodeObjectVersion) {
#define GET_MIMGOffsetMappingTable_IMPL
#define GET_MIMGG16MappingTable_IMPL
#define GET_MAIInstInfoTable_IMPL
#define GET_WMMAInstInfoTable_IMPL
#include "AMDGPUGenSearchableTables.inc"

int getMIMGOpcode(unsigned BaseOpcode, unsigned MIMGEncoding,
Expand Down Expand Up @@ -568,6 +569,11 @@ bool getMAIIsGFX940XDL(unsigned Opc) {
return Info && Info->is_gfx940_xdl;
}

bool getWMMAIsXDL(unsigned Opc) {
const WMMAInstInfo *Info = getWMMAInstInfoHelper(Opc);
return Info ? Info->is_wmma_xdl : false;
}

uint8_t mfmaScaleF8F6F4FormatToNumRegs(unsigned EncodingVal) {
switch (EncodingVal) {
case MFMAScaleFormats::FP6_E2M3:
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ struct True16D16Info {
unsigned LoOp;
};

// Per-opcode WMMA information record, read by getWMMAIsXDL().
// NOTE(review): presumably the row type of the TableGen-generated
// WMMAInstInfoTable, in which case field names, types and order must stay in
// sync with the .td definition - confirm before changing the layout.
struct WMMAInstInfo {
// Instruction opcode this record describes.
uint16_t Opcode;
// XDL flag for this opcode; this is the value getWMMAIsXDL() returns.
bool is_wmma_xdl;
};

#define GET_MIMGBaseOpcode_DECL
#define GET_MIMGDim_DECL
#define GET_MIMGEncoding_DECL
Expand All @@ -129,6 +134,7 @@ struct True16D16Info {
#define GET_isMFMA_F8F6F4Table_DECL
#define GET_isCvtScaleF32_F32F16ToF8F4Table_DECL
#define GET_True16D16Table_DECL
#define GET_WMMAInstInfoTable_DECL
#include "AMDGPUGenSearchableTables.inc"

namespace IsaInfo {
Expand Down Expand Up @@ -593,6 +599,9 @@ bool getMAIIsDGEMM(unsigned Opc);
LLVM_READONLY
bool getMAIIsGFX940XDL(unsigned Opc);

LLVM_READONLY
bool getWMMAIsXDL(unsigned Opc);

// Get an equivalent BitOp3 for a binary logical \p Opc.
// \returns BitOp3 modifier for the logical operation or zero.
// Used in VOPD3 conversion.
Expand Down
84 changes: 84 additions & 0 deletions llvm/test/CodeGen/AMDGPU/insert-delay-alu-wmma-xdl.mir
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -start-before=amdgpu-insert-delay-alu %s -o - | FileCheck %s

---
name: wmma_xdl_twoaddr_trans
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: {{^}}wmma_xdl_twoaddr_trans:
; CHECK: %bb.0:
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[8:15]
; CHECK-NEXT: v_exp_f32_e32 v16, v16
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
; CHECK-NEXT: v_add_nc_u32_e32 v17, v17, v8
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16, $vgpr17
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, 0, 0, 0, 0, implicit $exec
$vgpr16 = V_EXP_F32_e32 $vgpr16, implicit $exec, implicit $mode
$vgpr17 = V_ADD_U32_e32 $vgpr17, $vgpr8, implicit $exec
...

---
name: wmma_xdl_threeaddr_trans
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: {{^}}wmma_xdl_threeaddr_trans:
; CHECK: %bb.0:
; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[16:23]
; CHECK-NEXT: v_exp_f32_e32 v24, v24
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
; CHECK-NEXT: v_add_nc_u32_e32 v25, v25, v8
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24, $vgpr25
$vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_threeaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, 0, 0, 0, 0, implicit $exec
$vgpr24 = V_EXP_F32_e32 $vgpr24, implicit $exec, implicit $mode
$vgpr25 = V_ADD_U32_e32 $vgpr25, $vgpr8, implicit $exec
...

name: swmmac_xdl_twoaddr_trans
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: {{^}}swmmac_xdl_twoaddr_trans:
; CHECK: %bb.0:
; CHECK-NEXT: v_swmmac_f16_16x16x128_bf8_bf8 v[24:27], v[0:7], v[8:23], v[28:29]
; CHECK-NEXT: v_exp_f32_e32 v30, v30
; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2)
; CHECK-NEXT: v_add_nc_u32_e32 v31, v31, v24
liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28, $vgpr29, $vgpr30, $vgpr31
$vgpr24_vgpr25_vgpr26_vgpr27 = V_SWMMAC_F16_16X16X128_BF8_BF8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28_vgpr29, 0, 0, 0, implicit $exec
$vgpr30 = V_EXP_F32_e32 $vgpr30, implicit $exec, implicit $mode
$vgpr31 = V_ADD_U32_e32 $vgpr31, $vgpr24, implicit $exec
...

name: wmma_non_xdl_large_data_valu
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: {{^}}wmma_non_xdl_large_data_valu:
; CHECK: %bb.0:
; CHECK-NEXT: v_wmma_f32_16x16x4_f32 v[4:11], v[0:1], v[2:3], v[4:11] matrix_b_reuse
; CHECK-NEXT: v_exp_f32_e32 v12, v12
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1)
; CHECK-NEXT: v_add_nc_u32_e32 v13, v13, v8
liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4_vgpr5_vgpr6_vgpr7, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, $vgpr12, $vgpr13
$vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11 = V_WMMA_F32_16X16X4_F32_w32_twoaddr 8, $vgpr0_vgpr1, 8, $vgpr2_vgpr3, 8, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, 0, -1, 0, 0, implicit $exec
$vgpr12 = V_EXP_F32_e32 $vgpr12, implicit $exec, implicit $mode
$vgpr13 = V_ADD_U32_e32 $vgpr13, $vgpr8, implicit $exec
...

---
name: dot_xdl_dep_2
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: {{^}}dot_xdl_dep_2:
; CHECK: %bb.0:
; CHECK-NEXT: v_dot4_i32_iu8 v0, s2, s3, v0 neg_lo:[1,1,0]
; CHECK-NEXT: v_dot4_i32_iu8 v1, s2, s3, v2 neg_lo:[1,1,0]
; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_2)
; CHECK-NEXT: v_add_nc_u32_e32 v2, v0, v0
liveins: $vgpr0, $sgpr2, $sgpr3, $vgpr0, $vgpr1, $vgpr2
$vgpr0 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr0, 0, 0, 0, implicit $exec
$vgpr1 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr2, 0, 0, 0, implicit $exec
$vgpr2 = V_ADD_U32_e32 $vgpr0, $vgpr0, implicit $exec
...
Loading