Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
305 changes: 201 additions & 104 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "GCNSubtarget.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/CodeGen/TargetPassConfig.h"
Expand Down Expand Up @@ -115,126 +116,222 @@ class AMDGPURegBankLegalizeCombiner {
VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {};

bool isLaneMask(Register Reg) {
const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
if (RB && RB->getID() == AMDGPU::VCCRegBankID)
return true;
bool isLaneMask(Register Reg);
std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode);
std::pair<GUnmerge *, int> tryMatchRALFromUnmerge(Register Src);
Register getReadAnyLaneSrc(Register Src);
void replaceRegWithOrBuildCopy(Register Dst, Register Src);
bool tryEliminateReadAnyLane(MachineInstr &Copy);
void tryCombineCopy(MachineInstr &MI);
void tryCombineS1AnyExt(MachineInstr &MI);
};

const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
}
bool AMDGPURegBankLegalizeCombiner::isLaneMask(Register Reg) {
const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
if (RB && RB->getID() == AMDGPU::VCCRegBankID)
return true;

/// Erase \p MI, then also erase \p Optional0 when it is non-null and
/// isTriviallyDead() reports it as dead.
void cleanUpAfterCombine(MachineInstr &MI, MachineInstr *Optional0) {
  MI.eraseFromParent();

  // Nothing further to clean up when no secondary instruction was given.
  if (!Optional0)
    return;

  if (isTriviallyDead(*Optional0, MRI))
    Optional0->eraseFromParent();
}
const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
}

/// Fetch the instruction defining \p Src and compare its opcode with
/// \p Opcode.
///
/// \returns the defining instruction paired with the register held in its
/// operand 1 when the opcode matches, {nullptr, Register()} otherwise.
std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode) {
  MachineInstr *Def = MRI.getVRegDef(Src);
  if (Def->getOpcode() == Opcode)
    return {Def, Def->getOperand(1).getReg()};
  return {nullptr, Register()};
}
/// Check whether the instruction defining \p Src has opcode \p Opcode.
///
/// \returns {defining instruction, register in its operand 1} on a match and
/// {nullptr, Register()} when the opcode differs.
std::pair<MachineInstr *, Register>
AMDGPURegBankLegalizeCombiner::tryMatch(Register Src, unsigned Opcode) {
  MachineInstr *DefMI = MRI.getVRegDef(Src);
  const bool OpcodeMatches = DefMI->getOpcode() == Opcode;
  if (!OpcodeMatches)
    return {nullptr, Register()};

  Register Operand1Reg = DefMI->getOperand(1).getReg();
  return {DefMI, Operand1Reg};
}

void tryCombineCopy(MachineInstr &MI) {
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
// Skip copies of physical registers.
if (!Dst.isVirtual() || !Src.isVirtual())
return;

// This is a cross bank copy, sgpr S1 to lane mask.
//
// %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
// %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
// ->
// %Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %TruncS32Src:sgpr(s32)
if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
"sgpr S1 must be result of G_TRUNC of sgpr S32");

B.setInstr(MI);
// Ensure that truncated bits in BoolSrc are 0.
auto One = B.buildConstant({SgprRB, S32}, 1);
auto BoolSrc = B.buildAnd({SgprRB, S32}, TruncS32Src, One);
B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {Dst}, {BoolSrc});
cleanUpAfterCombine(MI, Trunc);
return;
}
std::pair<GUnmerge *, int>
AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) {
MachineInstr *ReadAnyLane = MRI.getVRegDef(Src);
if (ReadAnyLane->getOpcode() != AMDGPU::G_AMDGPU_READANYLANE)
return {nullptr, -1};

Register RALSrc = ReadAnyLane->getOperand(1).getReg();
if (auto *UnMerge = getOpcodeDef<GUnmerge>(RALSrc, MRI))
return {UnMerge, UnMerge->findRegisterDefOperandIdx(RALSrc, nullptr)};

// Src = G_AMDGPU_READANYLANE RALSrc
// Dst = COPY Src
// ->
// Dst = RALSrc
if (MRI.getRegBankOrNull(Dst) == VgprRB &&
MRI.getRegBankOrNull(Src) == SgprRB) {
auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE);
if (!RAL)
return;

assert(MRI.getRegBank(RALSrc) == VgprRB);
MRI.replaceRegWith(Dst, RALSrc);
cleanUpAfterCombine(MI, RAL);
return;
return {nullptr, -1};
}

Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) {
// Src = G_AMDGPU_READANYLANE RALSrc
auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE);
if (RAL)
return RALSrc;

// LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc
// LoSgpr = G_AMDGPU_READANYLANE LoVgpr
// HiSgpr = G_AMDGPU_READANYLANE HiVgpr
// Src = G_MERGE_VALUES LoSgpr, HiSgpr
auto *Merge = getOpcodeDef<GMergeLikeInstr>(Src, MRI);
if (Merge) {
unsigned NumElts = Merge->getNumSources();
auto [Unmerge, Idx] = tryMatchRALFromUnmerge(Merge->getSourceReg(0));
if (!Unmerge || Unmerge->getNumDefs() != NumElts || Idx != 0)
return {};

// Check if all elements are from same unmerge and there is no shuffling.
for (unsigned i = 1; i < NumElts; ++i) {
auto [UnmergeI, IdxI] = tryMatchRALFromUnmerge(Merge->getSourceReg(i));
if (UnmergeI != Unmerge || (unsigned)IdxI != i)
return {};
}
return Unmerge->getSourceReg();
}

void tryCombineS1AnyExt(MachineInstr &MI) {
// %Src:sgpr(S1) = G_TRUNC %TruncSrc
// %Dst = G_ANYEXT %Src:sgpr(S1)
// ->
// %Dst = G_... %TruncSrc
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
if (MRI.getType(Src) != S1)
return;

auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
if (!Trunc)
return;

LLT DstTy = MRI.getType(Dst);
LLT TruncSrcTy = MRI.getType(TruncSrc);

if (DstTy == TruncSrcTy) {
MRI.replaceRegWith(Dst, TruncSrc);
cleanUpAfterCombine(MI, Trunc);
return;
}
// ..., VgprI, ... = G_UNMERGE_VALUES VgprLarge
// SgprI = G_AMDGPU_READANYLANE VgprI
// SgprLarge = G_MERGE_VALUES ..., SgprI, ...
// ..., Src, ... = G_UNMERGE_VALUES SgprLarge
auto *UnMerge = getOpcodeDef<GUnmerge>(Src, MRI);
if (!UnMerge)
return {};

int Idx = UnMerge->findRegisterDefOperandIdx(Src, nullptr);
Merge = getOpcodeDef<GMergeLikeInstr>(UnMerge->getSourceReg(), MRI);
if (!Merge)
return {};

auto [RALElt, RALEltSrc] =
tryMatch(Merge->getSourceReg(Idx), AMDGPU::G_AMDGPU_READANYLANE);
if (RALElt)
return RALEltSrc;

return {};
}

/// Make \p Dst take its value from \p Src.
///
/// A virtual \p Dst is rewritten through MRI.replaceRegWith(). Any other
/// \p Dst cannot be rewritten that way, so a copy from \p Src into it is
/// built with the builder B instead.
void AMDGPURegBankLegalizeCombiner::replaceRegWithOrBuildCopy(Register Dst,
                                                              Register Src) {
  if (!Dst.isVirtual()) {
    B.buildCopy(Dst, Src);
    return;
  }

  MRI.replaceRegWith(Dst, Src);
}

/// Try to bypass a read-any-lane that feeds \p Copy, optionally looking
/// through one G_BITCAST sitting between the read-any-lane result and the
/// copy source.
///
/// \returns true when \p Copy was erased, false when nothing was changed
/// (non-virtual source, or getReadAnyLaneSrc() found no source register).
bool AMDGPURegBankLegalizeCombiner::tryEliminateReadAnyLane(
    MachineInstr &Copy) {
  Register Dst = Copy.getOperand(0).getReg();
  Register Src = Copy.getOperand(1).getReg();
  if (!Src.isVirtual())
    return false;

  // When the copy source is defined by a G_BITCAST, start the search from
  // the bitcast's input instead of from the copy source itself.
  MachineInstr &SrcMI = *MRI.getVRegDef(Src);
  const bool ThroughBitcast = SrcMI.getOpcode() == AMDGPU::G_BITCAST;
  Register RALDst = ThroughBitcast ? SrcMI.getOperand(1).getReg() : Src;

  Register RALSrc = getReadAnyLaneSrc(RALDst);
  if (!RALSrc)
    return false;

  // Anything built below is placed relative to the copy being replaced.
  B.setInstr(Copy);

  // Without a bitcast the read-any-lane source is used directly:
  //
  //   virtual Dst                      physical $Dst
  //   Src = READANYLANE RALSrc         Src = READANYLANE RALSrc
  //   Dst = Copy Src                   $Dst = Copy Src
  //   ->                               ->
  //   Dst = RALSrc                     $Dst = Copy RALSrc
  Register NewSrc = RALSrc;
  if (ThroughBitcast) {
    // With a bitcast, the cast is rebuilt on the read-any-lane source, with
    // the vgpr bank and the type of the original copy source:
    //
    //   virtual Dst                      physical $Dst
    //   RALDst = READANYLANE RALSrc      RALDst = READANYLANE RALSrc
    //   Src = G_BITCAST RALDst           Src = G_BITCAST RALDst
    //   Dst = Copy Src                   $Dst = Copy Src
    //   ->                               ->
    //   NewVgpr = G_BITCAST RALSrc       NewVgpr = G_BITCAST RALSrc
    //   Dst = NewVgpr                    $Dst = Copy NewVgpr
    NewSrc = B.buildBitcast({VgprRB, MRI.getType(Src)}, RALSrc).getReg(0);
  }
  replaceRegWithOrBuildCopy(Dst, NewSrc);

  eraseInstr(Copy, MRI, nullptr);
  return true;
}

/// Combine a copy \p MI. The read-any-lane elimination is attempted first;
/// if it does not apply, a virtual-to-virtual copy from an sgpr-bank source
/// into a lane mask is rewritten to G_AMDGPU_COPY_VCC_SCC.
void AMDGPURegBankLegalizeCombiner::tryCombineCopy(MachineInstr &MI) {
  if (tryEliminateReadAnyLane(MI))
    return;

  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  // Copies touching a physical register are left alone.
  if (!(Dst.isVirtual() && Src.isVirtual()))
    return;

  // Only the cross bank copy, sgpr S1 to lane mask, is handled below.
  if (!isLaneMask(Dst) || MRI.getRegBankOrNull(Src) != SgprRB)
    return;

  // %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
  // %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
  // ->
  // %One:sgpr(s32) = G_CONSTANT 1
  // %BoolSrc:sgpr(s32) = G_AND %TruncS32Src, %One
  // %Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %BoolSrc
  auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
  assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
         "sgpr S1 must be result of G_TRUNC of sgpr S32");

  B.setInstr(MI);
  // Mask with 1 so the bits dropped by the truncate are zero in the input
  // of G_AMDGPU_COPY_VCC_SCC.
  auto LowBitMask = B.buildConstant({SgprRB, S32}, 1);
  auto MaskedBool = B.buildAnd({SgprRB, S32}, TruncS32Src, LowBitMask);
  B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {Dst}, {MaskedBool});
  eraseInstr(MI, MRI, nullptr);
}

if (DstTy == S32 && TruncSrcTy == S64) {
auto Unmerge = B.buildUnmerge({SgprRB, S32}, TruncSrc);
MRI.replaceRegWith(Dst, Unmerge.getReg(0));
cleanUpAfterCombine(MI, Trunc);
return;
}
void AMDGPURegBankLegalizeCombiner::tryCombineS1AnyExt(MachineInstr &MI) {
// %Src:sgpr(S1) = G_TRUNC %TruncSrc
// %Dst = G_ANYEXT %Src:sgpr(S1)
// ->
// %Dst = G_... %TruncSrc
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
if (MRI.getType(Src) != S1)
return;

auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
if (!Trunc)
return;

LLT DstTy = MRI.getType(Dst);
LLT TruncSrcTy = MRI.getType(TruncSrc);

if (DstTy == TruncSrcTy) {
MRI.replaceRegWith(Dst, TruncSrc);
eraseInstr(MI, MRI, nullptr);
return;
}

if (DstTy == S64 && TruncSrcTy == S32) {
B.buildMergeLikeInstr(MI.getOperand(0).getReg(),
{TruncSrc, B.buildUndef({SgprRB, S32})});
cleanUpAfterCombine(MI, Trunc);
return;
}
B.setInstr(MI);

if (DstTy == S32 && TruncSrcTy == S16) {
B.buildAnyExt(Dst, TruncSrc);
cleanUpAfterCombine(MI, Trunc);
return;
}
if (DstTy == S32 && TruncSrcTy == S64) {
auto Unmerge = B.buildUnmerge({SgprRB, S32}, TruncSrc);
MRI.replaceRegWith(Dst, Unmerge.getReg(0));
eraseInstr(MI, MRI, nullptr);
return;
}

if (DstTy == S16 && TruncSrcTy == S32) {
B.buildTrunc(Dst, TruncSrc);
cleanUpAfterCombine(MI, Trunc);
return;
}
if (DstTy == S64 && TruncSrcTy == S32) {
B.buildMergeLikeInstr(MI.getOperand(0).getReg(),
{TruncSrc, B.buildUndef({SgprRB, S32})});
eraseInstr(MI, MRI, nullptr);
return;
}

llvm_unreachable("missing anyext + trunc combine");
if (DstTy == S32 && TruncSrcTy == S16) {
B.buildAnyExt(Dst, TruncSrc);
eraseInstr(MI, MRI, nullptr);
return;
}
};

if (DstTy == S16 && TruncSrcTy == S32) {
B.buildTrunc(Dst, TruncSrc);
eraseInstr(MI, MRI, nullptr);
return;
}

llvm_unreachable("missing anyext + trunc combine");
}

// Search through MRI for virtual registers with sgpr register bank and S1 LLT.
[[maybe_unused]] static Register getAnySgprS1(const MachineRegisterInfo &MRI) {
Expand Down
Loading