43 changes: 43 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/InferTypeInfoPass.h
@@ -0,0 +1,43 @@
#ifndef LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H

[Review comment, Contributor]: Not sure why this is necessary; can we keep the initial patch to just the LLT changes?

#define LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"

namespace llvm {

class InferTypeInfo : public MachineFunctionPass {
public:
static char ID;

private:
MachineRegisterInfo *MRI = nullptr;
MachineFunction *MF = nullptr;

MachineIRBuilder Builder;

/// Initialize the field members using \p MF.
void init(MachineFunction &MF);

public:
InferTypeInfo() : MachineFunctionPass(ID) {}

void getAnalysisUsage(AnalysisUsage &AU) const override;

bool runOnMachineFunction(MachineFunction &MF) override;

private:
bool inferTypeInfo(MachineFunction &MF);

bool shouldBeFP(MachineOperand &Op, unsigned Depth) const;

void updateDef(Register Reg);

void updateUse(MachineOperand &Op, bool FP);
};

} // end namespace llvm

#endif // LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H
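
Side note on how this pass might be hooked up (not part of this diff): a minimal sketch of a backend scheduling InferTypeInfo between the IRTranslator and the Legalizer. The pass config class, the choice of hook, and the assumption that the pass is meant to run before legalization are illustrative, not taken from the patch.

```cpp
// Hypothetical target pass config; InferTypeInfo is the pass declared above,
// everything else here is an assumption about how a target could schedule it.
#include "llvm/CodeGen/GlobalISel/InferTypeInfoPass.h"
#include "llvm/CodeGen/TargetPassConfig.h"

namespace {
class MyTargetPassConfig : public llvm::TargetPassConfig {
public:
  using TargetPassConfig::TargetPassConfig;

  // Run FP-type inference on generic MIR before the legalizer so that
  // legalization rules can distinguish float from integer LLTs.
  bool addPreLegalizeMachineIR() override {
    addPass(new llvm::InferTypeInfo());
    return false;
  }
};
} // end anonymous namespace
```
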
llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
@@ -27,6 +27,7 @@
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "legalizer"
@@ -44,6 +45,7 @@ class LegalizationArtifactCombiner {
case TargetOpcode::G_SEXT:
case TargetOpcode::G_ZEXT:
case TargetOpcode::G_ANYEXT:
case TargetOpcode::G_BITCAST:
return true;
default:
return false;
@@ -507,6 +509,53 @@ class LegalizationArtifactCombiner {
markInstAndDefDead(MI, CastMI, DeadInsts);
return true;
}
} else if (CastOpc == TargetOpcode::G_BITCAST) {

// %1:_(<2 x i32>) = G_BITCAST %0(<2 x f32>)
// %2:_(i16), %3:_(i16), %4:_(i16), %5:_(i16) = G_UNMERGE_VALUES %1
// =>
// %6:_(f32), %7:_(f32) = G_UNMERGE_VALUES %0
// %8:_(i32) = G_BITCAST %6
// %2:_(i16), %3:_(i16) = G_UNMERGE_VALUES %8
// %9:_(i32) = G_BITCAST %7
// %4:_(i16), %5:_(i16) = G_UNMERGE_VALUES %9

if (CastSrcTy.isScalar() || SrcTy.isScalar() || DestTy.isVector() ||
    DestTy == SrcTy.getScalarType())
return false;

// NewNumDefs1: number of elements in the bitcast source.
// NewNumDefs2: number of original unmerge results covered by each of them.
const unsigned NewNumDefs1 = CastSrcTy.getNumElements();
const unsigned NewNumDefs2 = NumDefs / NewNumDefs1;

if (NewNumDefs2 <= 1)
return false;

SmallVector<Register, 8> NewUnmergeRegs(NewNumDefs1);
for (unsigned Idx = 0; Idx < NewNumDefs1; ++Idx)
NewUnmergeRegs[Idx] =
    MRI.createGenericVirtualRegister(CastSrcTy.getElementType());

Builder.setInstr(MI);

// Unmerge the bitcast source into its elements.
auto NewUnmerge = Builder.buildUnmerge(NewUnmergeRegs, CastSrcReg);

// Collect the results of the original unmerge.
SmallVector<Register, 8> DstRegs(NumDefs);
for (unsigned Idx = 0; Idx < NumDefs; ++Idx)
DstRegs[Idx] = MI.getOperand(Idx).getReg();

// Bitcast each element and unmerge it into the corresponding group of
// original results.
auto *It = DstRegs.begin();
for (auto &Def : NewUnmerge->all_defs()) {
auto Bitcast = Builder.buildBitcast(SrcTy.getElementType(), Def.getReg());
auto *Begin = It;
It += NewNumDefs2;
ArrayRef<Register> Regs(Begin, It);
Builder.buildUnmerge(Regs, Bitcast);
}

UpdatedDefs.append(NewUnmergeRegs.begin(), NewUnmergeRegs.end());
UpdatedDefs.append(DstRegs.begin(), DstRegs.end());
markInstAndDefDead(MI, CastMI, DeadInsts);
return true;
}

// TODO: support combines with other casts as well
@@ -1165,8 +1214,9 @@ class LegalizationArtifactCombiner {
++j, ++DefIdx)
DstRegs.push_back(MI.getReg(DefIdx));

if (ConvertOp) {
LLT MergeDstTy = MRI.getType(SrcDef->getOperand(0).getReg());
LLT MergeDstTy = MRI.getType(SrcDef->getOperand(0).getReg());

if (ConvertOp && DestTy != MergeDstTy) {

// This is a vector that is being split and casted. Extract to the
// element type, and do the conversion on the scalars (or smaller
@@ -1187,6 +1237,7 @@ class LegalizationArtifactCombiner {
// %7(<2 x s16>), %7(<2 x s16>) = G_UNMERGE_VALUES %9

Register TmpReg = MRI.createGenericVirtualRegister(MergeEltTy);
assert(MRI.getType(TmpReg) != MRI.getType(MergeI->getOperand(Idx + 1).getReg()));
Builder.buildInstr(ConvertOp, {TmpReg},
{MergeI->getOperand(Idx + 1).getReg()});
Builder.buildUnmerge(DstRegs, TmpReg);
@@ -1232,14 +1283,15 @@ class LegalizationArtifactCombiner {
ConvertOp = TargetOpcode::G_BITCAST;
}

if (ConvertOp) {
if (ConvertOp && DestTy != MergeSrcTy) {
Builder.setInstr(MI);

for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
Register DefReg = MI.getOperand(Idx).getReg();
Register MergeSrc = MergeI->getOperand(Idx + 1).getReg();

if (!MRI.use_empty(DefReg)) {
assert(MRI.getType(DefReg) != MRI.getType(MergeSrc));
Builder.buildInstr(ConvertOp, {DefReg}, {MergeSrc});
UpdatedDefs.push_back(DefReg);
}
@@ -1398,6 +1450,7 @@ class LegalizationArtifactCombiner {
case TargetOpcode::G_EXTRACT:
case TargetOpcode::G_TRUNC:
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_BITCAST:
// Adding Use to ArtifactList.
WrapperObserver.changedInstr(Use);
break;
@@ -1425,6 +1478,7 @@ class LegalizationArtifactCombiner {
static Register getArtifactSrcReg(const MachineInstr &MI) {
switch (MI.getOpcode()) {
case TargetOpcode::COPY:
case TargetOpcode::G_BITCAST:
case TargetOpcode::G_TRUNC:
case TargetOpcode::G_ZEXT:
case TargetOpcode::G_ANYEXT:
111 changes: 98 additions & 13 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -282,6 +286,10 @@ LegalityPredicate typePairAndMemDescInSet(
std::initializer_list<TypePairAndMemDesc> TypesAndMemDescInit);
/// True iff the specified type index is a scalar.
LegalityPredicate isScalar(unsigned TypeIdx);
/// True iff the specified type index is an integer.
LegalityPredicate isInteger(unsigned TypeIdx);
/// True iff the specified type index is a float.
LegalityPredicate isFloat(unsigned TypeIdx);
/// True iff the specified type index is a vector.
LegalityPredicate isVector(unsigned TypeIdx);
/// True iff the specified type index is a pointer (with any address space).
@@ -292,6 +296,14 @@ LegalityPredicate isPointer(unsigned TypeIdx, unsigned AddrSpace);
/// True iff the specified type index is a vector of pointers (with any address
/// space).
LegalityPredicate isPointerVector(unsigned TypeIdx);
/// True iff the specified type index is a vector of integers.
LegalityPredicate isIntegerVector(unsigned TypeIdx);
/// True iff the specified type index is a vector of floats.
LegalityPredicate isFloatVector(unsigned TypeIdx);

/// True iff the specified type index is a float or a vector of floats.
LegalityPredicate isFloatOrFloatVector(unsigned TypeIdx);

/// True iff the specified type index is an integer or a vector of integers.
LegalityPredicate isIntegerOrIntegerVector(unsigned TypeIdx);

/// True if the type index is a vector with element type \p EltTy
LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT EltTy);
@@ -330,6 +342,10 @@ LegalityPredicate sizeIs(unsigned TypeIdx, unsigned Size);
/// True iff the specified type indices are both the same bit size.
LegalityPredicate sameSize(unsigned TypeIdx0, unsigned TypeIdx1);

/// True iff the scalar (or element) kind of the specified type index matches
/// the scalar kind of \p Ty (integer vs. floating point).
LegalityPredicate sameScalarKind(unsigned TypeIdx, LLT Ty);

/// True iff the kind of the specified type index matches the kind of \p Ty.
LegalityPredicate sameKind(unsigned TypeIdx, LLT Ty);

/// True iff the first type index has a larger total bit size than second type
/// index.
LegalityPredicate largerThan(unsigned TypeIdx0, unsigned TypeIdx1);
@@ -374,13 +390,15 @@ LegalizeMutation changeElementCountTo(unsigned TypeIdx, unsigned FromTypeIdx);

/// Keep the same scalar or element type as \p TypeIdx, but take the number of
/// elements from \p Ty.
LegalizeMutation changeElementCountTo(unsigned TypeIdx, LLT Ty);
LegalizeMutation changeElementCountTo(unsigned TypeIdx, ElementCount EC);

/// Change the scalar size or element size to have the same scalar size as type
/// index \p FromIndex. Unlike changeElementTo, this discards pointer types and
/// only changes the size.
LegalizeMutation changeElementSizeTo(unsigned TypeIdx, unsigned FromTypeIdx);

/// Change the type for the given type index to the same-sized integer type.
LegalizeMutation changeToInteger(unsigned TypeIdx);

/// Widen the scalar type or vector element type for the given type index to the
/// next power of 2.
LegalizeMutation widenScalarOrEltToNextPow2(unsigned TypeIdx, unsigned Min = 0);
@@ -942,6 +960,16 @@ class LegalizeRuleSet {
LegalizeMutations::widenScalarOrEltToNextPow2(TypeIdx, MinSize));
}

/// Bitcast float scalars or float vectors whose total size is not a power of
/// two to the same-sized integer type.
LegalizeRuleSet &widenScalarToNextPow2Bitcast(unsigned TypeIdx,
unsigned MinSize = 0) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(
LegalizeAction::Bitcast,
all(isFloatOrFloatVector(TypeIdx), sizeNotPow2(typeIdx(TypeIdx))),
changeToInteger(TypeIdx));
}

/// Widen the scalar to the next multiple of Size. No effect if the
/// type is not a scalar or is a multiple of Size.
LegalizeRuleSet &widenScalarToNextMultipleOf(unsigned TypeIdx,
@@ -997,20 +1025,32 @@ class LegalizeRuleSet {
LegalizeRuleSet &minScalarOrElt(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::WidenScalar,
scalarOrEltNarrowerThan(TypeIdx, Ty.getScalarSizeInBits()),
changeElementTo(typeIdx(TypeIdx), Ty));
return actionIf(
LegalizeAction::WidenScalar,
all(sameScalarKind(TypeIdx, Ty),
scalarOrEltNarrowerThan(TypeIdx, Ty.getScalarSizeInBits())),
changeElementTo(typeIdx(TypeIdx), Ty));
}
/// Bitcast float scalars or vector elements narrower than \p Ty to the
/// same-sized integer type.
LegalizeRuleSet &minScalarOrEltBitcast(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(
LegalizeAction::Bitcast,
all(isFloatOrFloatVector(TypeIdx),
scalarOrEltNarrowerThan(TypeIdx, Ty.getScalarSizeInBits())),
changeToInteger(typeIdx(TypeIdx)));
}

/// Ensure the scalar or element is at least as wide as Ty.
LegalizeRuleSet &minScalarOrEltIf(LegalityPredicate Predicate,
unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::WidenScalar,
all(Predicate, scalarOrEltNarrowerThan(
TypeIdx, Ty.getScalarSizeInBits())),
changeElementTo(typeIdx(TypeIdx), Ty));
return actionIf(
LegalizeAction::WidenScalar,
all(Predicate, sameScalarKind(TypeIdx, Ty),
scalarOrEltNarrowerThan(TypeIdx, Ty.getScalarSizeInBits())),
changeElementTo(typeIdx(TypeIdx), Ty));
}

/// Ensure the vector size is at least as wide as VectorSize by promoting the
@@ -1039,14 +1079,23 @@ class LegalizeRuleSet {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::WidenScalar,
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
all(sameKind(TypeIdx, Ty),
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits())),
changeTo(typeIdx(TypeIdx), Ty));
}
LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
if (!Pred)
return *this;
return minScalar(TypeIdx, Ty);
}
/// Bitcast float scalars narrower than \p Ty to the same-sized integer type.
LegalizeRuleSet &minScalarBitcast(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::Bitcast,
all(isFloatOrFloatVector(TypeIdx),
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits())),
changeToInteger(typeIdx(TypeIdx)));
}

/// Ensure the scalar is at least as wide as Ty if condition is met.
LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,
@@ -1068,19 +1117,39 @@
LegalizeRuleSet &maxScalarOrElt(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::NarrowScalar,
scalarOrEltWiderThan(TypeIdx, Ty.getScalarSizeInBits()),
changeElementTo(typeIdx(TypeIdx), Ty));
return actionIf(
LegalizeAction::NarrowScalar,
all(sameScalarKind(TypeIdx, Ty),
scalarOrEltWiderThan(TypeIdx, Ty.getScalarSizeInBits())),
changeElementTo(typeIdx(TypeIdx), Ty));
}
/// Bitcast float scalars or vector elements wider than \p Ty to the
/// same-sized integer type.
LegalizeRuleSet &maxScalarOrEltBitcast(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(
LegalizeAction::NarrowScalar,
all(isFloatOrFloatVector(TypeIdx),
scalarOrEltWiderThan(TypeIdx, Ty.getScalarSizeInBits())),
changeToInteger(typeIdx(TypeIdx)));
}

/// Ensure the scalar is at most as wide as Ty.
LegalizeRuleSet &maxScalar(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::NarrowScalar,
scalarWiderThan(TypeIdx, Ty.getSizeInBits()),
all(sameKind(TypeIdx, Ty),
scalarWiderThan(TypeIdx, Ty.getSizeInBits())),
changeTo(typeIdx(TypeIdx), Ty));
}
/// Bitcast float scalars wider than \p Ty to the same-sized integer type.
LegalizeRuleSet &maxScalarBitcast(unsigned TypeIdx, const LLT Ty) {
using namespace LegalityPredicates;
using namespace LegalizeMutations;
return actionIf(LegalizeAction::NarrowScalar,
all(isFloatOrFloatVector(TypeIdx),
scalarWiderThan(TypeIdx, Ty.getSizeInBits())),
changeToInteger(typeIdx(TypeIdx)));
}

/// Conditionally limit the maximum size of the scalar.
/// For example, when the maximum size of one type depends on the size of
@@ -1103,10 +1172,20 @@
/// Limit the range of scalar sizes to MinTy and MaxTy.
LegalizeRuleSet &clampScalar(unsigned TypeIdx, const LLT MinTy,
const LLT MaxTy) {
assert(MinTy.getKind() == MaxTy.getKind() &&
"Expected LLT of the same kind");
assert(MinTy.isScalar() && MaxTy.isScalar() && "Expected scalar types");
return minScalar(TypeIdx, MinTy).maxScalar(TypeIdx, MaxTy);
}

/// Bitcast float scalars outside the range [MinTy, MaxTy] to the same-sized
/// integer type.
LegalizeRuleSet &clampScalarBitcast(unsigned TypeIdx, const LLT MinTy,
const LLT MaxTy) {
assert(MinTy.getKind() == MaxTy.getKind() &&
"Expected LLT of the same kind");
assert(MinTy.isScalar() && MaxTy.isScalar() && "Expected scalar types");
return minScalarBitcast(TypeIdx, MinTy).maxScalarBitcast(TypeIdx, MaxTy);
}

LegalizeRuleSet &clampScalar(bool Pred, unsigned TypeIdx, const LLT MinTy,
const LLT MaxTy) {
if (!Pred)
@@ -1120,6 +1199,12 @@
return minScalarOrElt(TypeIdx, MinTy).maxScalarOrElt(TypeIdx, MaxTy);
}

/// Bitcast float scalars or vector elements outside the range
/// [MinTy, MaxTy] to the same-sized integer type.
LegalizeRuleSet &clampScalarOrEltBitcast(unsigned TypeIdx, const LLT MinTy,
const LLT MaxTy) {
return minScalarOrEltBitcast(TypeIdx, MinTy)
.maxScalarOrEltBitcast(TypeIdx, MaxTy);
}

/// Widen the scalar to match the size of another.
LegalizeRuleSet &minScalarSameAs(unsigned TypeIdx, unsigned LargeTypeIdx) {
typeIdx(TypeIdx);
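
Usage illustration for the new *Bitcast rule helpers (not part of the patch): a sketch of how a target's rule set might route small or oversized float types through a bitcast to the same-sized integer before the existing integer clamping applies. The opcode, the type bounds, and the overall strategy are assumptions for illustration only.

```cpp
// Fragment from a hypothetical <Target>LegalizerInfo constructor.
using namespace TargetOpcode;
const LLT S32 = LLT::scalar(32);
const LLT S64 = LLT::scalar(64);

getActionDefinitionsBuilder(G_LOAD)
    .legalFor({S32, S64})
    // Float results outside [32, 64] bits: bitcast to the same-sized
    // integer type first (the new clampScalarBitcast rule from this patch).
    .clampScalarBitcast(0, S32, S64)
    // Integer results are clamped as before; the sameKind() predicate added
    // in this patch keeps this rule from firing on float types.
    .clampScalar(0, S32, S64);
```

Read this way, the apparent intent is that float values a target cannot widen or narrow directly are first converted to integers and then legalized with the existing integer machinery.
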
4 changes: 3 additions & 1 deletion llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -239,7 +239,7 @@ class MachineIRBuilder {
unsigned getOpcodeForMerge(const DstOp &DstOp, ArrayRef<SrcOp> SrcOps) const;

protected:
void validateTruncExt(const LLT Dst, const LLT Src, bool IsExtend);
void validateTruncExt(const LLT Dst, const LLT Src, unsigned Opc);

void validateUnaryOp(const LLT Res, const LLT Op0);
void validateBinaryOp(const LLT Res, const LLT Op0, const LLT Op1);
@@ -802,6 +802,8 @@ class MachineIRBuilder {
MachineInstrBuilder buildExtOrTrunc(unsigned ExtOpc, const DstOp &Res,
const SrcOp &Op);

MachineInstrBuilder buildTruncLike(const DstOp &Res, const SrcOp &Op);

/// Build and inserts \p Res = \p G_AND \p Op, \p LowBitsSet(ImmOp)
/// Since there is no G_ZEXT_INREG like G_SEXT_INREG, the instruction is
/// emulated using G_AND.
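
On the MachineIRBuilder side, the validateTruncExt signature change (per-opcode validation instead of a single IsExtend flag) and the new buildTruncLike entry point suggest that truncations are now built differently for float and integer LLTs. A minimal caller sketch follows, under the assumption that buildTruncLike simply selects a truncation opcode that is valid for the destination type; the helper name below is made up for illustration.

```cpp
// Hypothetical helper; only Builder.buildTruncLike() comes from this patch.
llvm::MachineInstrBuilder narrowValue(llvm::MachineIRBuilder &Builder,
                                      llvm::Register Dst, llvm::Register Src) {
  // Assumption: behaves like buildTrunc() for integer LLTs and emits the
  // appropriate float-aware truncation otherwise, with validateTruncExt()
  // checking the chosen opcode against the destination and source types.
  return Builder.buildTruncLike(Dst, Src);
}
```
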