Skip to content

Commit cd0a8e7

Browse files
committed
FPInfo: IRTranslator and CallLowering
1 parent 8bb477a commit cd0a8e7

File tree

6 files changed

+117
-76
lines changed

6 files changed

+117
-76
lines changed

llvm/include/llvm/CodeGen/Analysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ inline void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL,
9595
/// with the in-memory offsets of each of the individual values.
9696
///
9797
void computeValueLLTs(const DataLayout &DL, Type &Ty,
98-
SmallVectorImpl<LLT> &ValueTys,
98+
SmallVectorImpl<LLT> &ValueTys, bool EnableFPInfo,
9999
SmallVectorImpl<uint64_t> *Offsets = nullptr,
100100
uint64_t StartingOffset = 0);
101101

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,8 @@ class IRTranslator : public MachineFunctionPass {
619619

620620
CodeGenOptLevel OptLevel;
621621

622+
bool EnableFPInfo;
623+
622624
/// Current optimization remark emitter. Used to report failures.
623625
std::unique_ptr<OptimizationRemarkEmitter> ORE;
624626

@@ -772,7 +774,7 @@ class IRTranslator : public MachineFunctionPass {
772774
BranchProbability Prob = BranchProbability::getUnknown());
773775

774776
public:
775-
IRTranslator(CodeGenOptLevel OptLevel = CodeGenOptLevel::None);
777+
IRTranslator(CodeGenOptLevel OptLevel = CodeGenOptLevel::None, bool EnableFPInfo = false);
776778

777779
StringRef getPassName() const override { return "IRTranslator"; }
778780

llvm/lib/CodeGen/Analysis.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL,
139139
}
140140

141141
void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty,
142-
SmallVectorImpl<LLT> &ValueTys,
142+
SmallVectorImpl<LLT> &ValueTys, bool EnableFPInfo,
143143
SmallVectorImpl<uint64_t> *Offsets,
144144
uint64_t StartingOffset) {
145145
// Given a struct type, recursively traverse the elements.
@@ -150,7 +150,7 @@ void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty,
150150
const StructLayout *SL = Offsets ? DL.getStructLayout(STy) : nullptr;
151151
for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
152152
uint64_t EltOffset = SL ? SL->getElementOffset(I) : 0;
153-
computeValueLLTs(DL, *STy->getElementType(I), ValueTys, Offsets,
153+
computeValueLLTs(DL, *STy->getElementType(I), ValueTys, EnableFPInfo, Offsets,
154154
StartingOffset + EltOffset);
155155
}
156156
return;
@@ -160,15 +160,15 @@ void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty,
160160
Type *EltTy = ATy->getElementType();
161161
uint64_t EltSize = DL.getTypeAllocSize(EltTy).getFixedValue();
162162
for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i)
163-
computeValueLLTs(DL, *EltTy, ValueTys, Offsets,
163+
computeValueLLTs(DL, *EltTy, ValueTys, EnableFPInfo, Offsets,
164164
StartingOffset + i * EltSize);
165165
return;
166166
}
167167
// Interpret void as zero return values.
168168
if (Ty.isVoidTy())
169169
return;
170170
// Base case: we can get an LLT for this LLVM IR type.
171-
ValueTys.push_back(getLLTForType(Ty, DL));
171+
ValueTys.push_back(getLLTForType(Ty, DL, EnableFPInfo));
172172
if (Offsets != nullptr)
173173
Offsets->push_back(StartingOffset * 8);
174174
}

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/CodeGen/MachineOperand.h"
2121
#include "llvm/CodeGen/MachineRegisterInfo.h"
2222
#include "llvm/CodeGen/TargetLowering.h"
23+
#include "llvm/CodeGen/TargetOpcodes.h"
2324
#include "llvm/IR/DataLayout.h"
2425
#include "llvm/IR/LLVMContext.h"
2526
#include "llvm/IR/Module.h"
@@ -158,7 +159,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
158159

159160
if (const Function *F = dyn_cast<Function>(CalleeV)) {
160161
if (F->hasFnAttribute(Attribute::NonLazyBind)) {
161-
LLT Ty = getLLTForType(*F->getType(), DL);
162+
LLT Ty = getLLTForType(*F->getType(), DL, /* EnableFPInfo */ true);
162163
Register Reg = MIRBuilder.buildGlobalValue(Ty, F).getReg(0);
163164
Info.Callee = MachineOperand::CreateReg(Reg, false);
164165
} else {
@@ -573,8 +574,13 @@ static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
573574
TypeSize::isKnownGT(PartSize, SrcTy.getElementType().getSizeInBits())) {
574575
// Vector was scalarized, and the elements extended.
575576
auto UnmergeToEltTy = B.buildUnmerge(SrcTy.getElementType(), SrcReg);
576-
for (int i = 0, e = DstRegs.size(); i != e; ++i)
577-
B.buildAnyExt(DstRegs[i], UnmergeToEltTy.getReg(i));
577+
for (int i = 0, e = DstRegs.size(); i != e; ++i) {
578+
if (SrcTy.isFloatVector() && ExtendOp == TargetOpcode::G_FPEXT) {
579+
B.buildFPExt(DstRegs[i], UnmergeToEltTy.getReg(i));
580+
} else {
581+
B.buildAnyExt(DstRegs[i], UnmergeToEltTy.getReg(i));
582+
}
583+
}
578584
return;
579585
}
580586

@@ -599,9 +605,16 @@ static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
599605
SrcTy.getScalarSizeInBits() > PartTy.getSizeInBits()) {
600606
LLT ExtTy =
601607
LLT::vector(SrcTy.getElementCount(),
602-
LLT::scalar(PartTy.getScalarSizeInBits() * DstRegs.size() /
608+
LLT::integer(PartTy.getScalarSizeInBits() * DstRegs.size() /
603609
SrcTy.getNumElements()));
604-
auto Ext = B.buildAnyExt(ExtTy, SrcReg);
610+
Register Ext;
611+
if (SrcTy.isFloatVector() && ExtendOp == TargetOpcode::G_FPEXT) {
612+
auto Cast = B.buildBitcast(SrcTy.dropType(), SrcReg).getReg(0);
613+
Ext = B.buildAnyExt(ExtTy, Cast).getReg(0);
614+
} else {
615+
Ext = B.buildAnyExt(ExtTy, SrcReg).getReg(0);
616+
}
617+
605618
B.buildUnmerge(DstRegs, Ext);
606619
return;
607620
}
@@ -780,11 +793,11 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
780793
const MVT ValVT = VA.getValVT();
781794
const MVT LocVT = VA.getLocVT();
782795

783-
const LLT LocTy(LocVT);
784-
const LLT ValTy(ValVT);
796+
const LLT LocTy(LocVT, /* EnableFPInfo */ true);
797+
const LLT ValTy(ValVT, /* EnableFPInfo */ true);
785798
const LLT NewLLT = Handler.isIncomingArgumentHandler() ? LocTy : ValTy;
786799
const EVT OrigVT = EVT::getEVT(Args[i].Ty);
787-
const LLT OrigTy = getLLTForType(*Args[i].Ty, DL);
800+
const LLT OrigTy = getLLTForType(*Args[i].Ty, DL, /* EnableFPInfo */ true);
788801
const LLT PointerTy = LLT::pointer(
789802
AllocaAddressSpace, DL.getPointerSizeInBits(AllocaAddressSpace));
790803

@@ -822,8 +835,11 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
822835
if (!Handler.isIncomingArgumentHandler() && OrigTy != ValTy &&
823836
VA.getLocInfo() != CCValAssign::Indirect) {
824837
assert(Args[i].OrigRegs.size() == 1);
838+
unsigned ExtendOp = extendOpFromFlags(Args[i].Flags[0]);
839+
if ((OrigTy.isFloat() || OrigTy.isFloatVector()) && ValTy.isFloat())
840+
ExtendOp = TargetOpcode::G_FPEXT;
825841
buildCopyToRegs(MIRBuilder, Args[i].Regs, Args[i].OrigRegs[0], OrigTy,
826-
ValTy, extendOpFromFlags(Args[i].Flags[0]));
842+
ValTy, ExtendOp);
827843
}
828844

829845
bool IndirectParameterPassingHandled = false;
@@ -1003,7 +1019,7 @@ void CallLowering::insertSRetLoads(MachineIRBuilder &MIRBuilder, Type *RetTy,
10031019
Align BaseAlign = DL.getPrefTypeAlign(RetTy);
10041020
Type *RetPtrTy =
10051021
PointerType::get(RetTy->getContext(), DL.getAllocaAddrSpace());
1006-
LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetPtrTy), DL);
1022+
LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetPtrTy), DL, /* EnableFPInfo */ true);
10071023

10081024
MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
10091025

@@ -1033,7 +1049,7 @@ void CallLowering::insertSRetStores(MachineIRBuilder &MIRBuilder, Type *RetTy,
10331049
unsigned NumValues = SplitVTs.size();
10341050
Align BaseAlign = DL.getPrefTypeAlign(RetTy);
10351051
unsigned AS = DL.getAllocaAddrSpace();
1036-
LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetTy->getContext(), AS), DL);
1052+
LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetTy->getContext(), AS), DL, /* EnableFPInfo */ true);
10371053

10381054
MachinePointerInfo PtrInfo(AS);
10391055

@@ -1291,8 +1307,8 @@ void CallLowering::ValueHandler::copyArgumentMemory(
12911307
Register CallLowering::ValueHandler::extendRegister(Register ValReg,
12921308
const CCValAssign &VA,
12931309
unsigned MaxSizeBits) {
1294-
LLT LocTy{VA.getLocVT()};
1295-
LLT ValTy{VA.getValVT()};
1310+
LLT LocTy(VA.getLocVT(), /* EnableFPInfo */ true);
1311+
LLT ValTy(VA.getValVT(), /* EnableFPInfo */ true);
12961312

12971313
if (LocTy.getSizeInBits() == ValTy.getSizeInBits())
12981314
return ValReg;
@@ -1383,7 +1399,7 @@ static bool isCopyCompatibleType(LLT SrcTy, LLT DstTy) {
13831399
void CallLowering::IncomingValueHandler::assignValueToReg(
13841400
Register ValVReg, Register PhysReg, const CCValAssign &VA) {
13851401
const MVT LocVT = VA.getLocVT();
1386-
const LLT LocTy(LocVT);
1402+
const LLT LocTy(LocVT, true);
13871403
const LLT RegTy = MRI.getType(ValVReg);
13881404

13891405
if (isCopyCompatibleType(RegTy, LocTy)) {

0 commit comments

Comments
 (0)