Skip to content

Commit a434d57

Browse files
committed
add draft for InferTypeInfoPass
1 parent 875848b commit a434d57

File tree

6 files changed

+345
-0
lines changed

6 files changed

+345
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H
2+
#define LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H
3+
4+
#include "llvm/Analysis/AliasAnalysis.h"
5+
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
6+
#include "llvm/CodeGen/MachineFunction.h"
7+
#include "llvm/CodeGen/MachineFunctionPass.h"
8+
9+
namespace llvm {
10+
11+
class InferTypeInfo : public MachineFunctionPass {
12+
public:
13+
static char ID;
14+
15+
private:
16+
MachineRegisterInfo *MRI = nullptr;
17+
MachineFunction *MF = nullptr;
18+
19+
MachineIRBuilder Builder;
20+
21+
/// Initialize the field members using \p MF.
22+
void init(MachineFunction &MF);
23+
24+
public:
25+
InferTypeInfo() : MachineFunctionPass(ID) {}
26+
27+
void getAnalysisUsage(AnalysisUsage &AU) const override;
28+
29+
bool runOnMachineFunction(MachineFunction &MF) override;
30+
31+
private:
32+
bool inferTypeInfo(MachineFunction &MF);
33+
34+
bool shouldBeFP(MachineOperand &Op, unsigned Depth) const;
35+
36+
void updateDef(Register Reg);
37+
38+
void updateUse(MachineOperand &Op, bool FP);
39+
};
40+
41+
} // end namespace llvm
42+
43+
#endif // LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ void initializeHardwareLoopsLegacyPass(PassRegistry &);
131131
void initializeMIRProfileLoaderPassPass(PassRegistry &);
132132
void initializeIRSimilarityIdentifierWrapperPassPass(PassRegistry &);
133133
void initializeIRTranslatorPass(PassRegistry &);
134+
void initializeInferTypeInfoPass(PassRegistry &);
134135
void initializeIVUsersWrapperPassPass(PassRegistry &);
135136
void initializeIfConverterPass(PassRegistry &);
136137
void initializeImmutableModuleSummaryIndexWrapperPassPass(PassRegistry &);

llvm/lib/CodeGen/GlobalISel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_llvm_component_library(LLVMGlobalISel
1313
GIMatchTableExecutor.cpp
1414
GISelChangeObserver.cpp
1515
IRTranslator.cpp
16+
InferTypeInfoPass.cpp
1617
InlineAsmLowering.cpp
1718
InstructionSelect.cpp
1819
InstructionSelector.cpp

llvm/lib/CodeGen/GlobalISel/GlobalISel.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using namespace llvm;
1616

1717
void llvm::initializeGlobalISel(PassRegistry &Registry) {
1818
initializeIRTranslatorPass(Registry);
19+
initializeInferTypeInfoPass(Registry);
1920
initializeLegalizerPass(Registry);
2021
initializeLoadStoreOptPass(Registry);
2122
initializeLocalizerPass(Registry);
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
//===- llvm/CodeGen/GlobalISel/InferTypeInfoPass.cpp - StripTypeInfoPass ---*-
2+
// C++ -*-==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
/// \file
10+
/// This file implements the InferTypeInfoPass class.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "llvm/CodeGen/GlobalISel/InferTypeInfoPass.h"
14+
#include "llvm/ADT/STLExtras.h"
15+
#include "llvm/ADT/SmallSet.h"
16+
#include "llvm/Analysis/AliasAnalysis.h"
17+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
18+
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19+
#include "llvm/CodeGen/GlobalISel/LoadStoreOpt.h"
20+
#include "llvm/CodeGen/GlobalISel/Utils.h"
21+
#include "llvm/CodeGen/MachineBasicBlock.h"
22+
#include "llvm/CodeGen/MachineFrameInfo.h"
23+
#include "llvm/CodeGen/MachineFunction.h"
24+
#include "llvm/CodeGen/MachineInstr.h"
25+
#include "llvm/CodeGen/MachineMemOperand.h"
26+
#include "llvm/CodeGen/MachineOperand.h"
27+
#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
28+
#include "llvm/CodeGen/MachineRegisterInfo.h"
29+
#include "llvm/CodeGen/Register.h"
30+
#include "llvm/CodeGen/TargetLowering.h"
31+
#include "llvm/CodeGen/TargetOpcodes.h"
32+
#include "llvm/IR/IntrinsicInst.h"
33+
#include "llvm/IR/IntrinsicsAMDGPU.h"
34+
#include "llvm/InitializePasses.h"
35+
36+
#define DEBUG_TYPE "mir-infer-type-info"
37+
38+
using namespace llvm;
39+
40+
char InferTypeInfo::ID = 0;
41+
42+
INITIALIZE_PASS_BEGIN(InferTypeInfo, DEBUG_TYPE, "TODO", false, false)
43+
INITIALIZE_PASS_END(InferTypeInfo, DEBUG_TYPE, "TODO", false, false)
44+
45+
void InferTypeInfo::init(MachineFunction &MF) {
46+
this->MF = &MF;
47+
MRI = &MF.getRegInfo();
48+
Builder.setMF(MF);
49+
}
50+
51+
void InferTypeInfo::getAnalysisUsage(AnalysisUsage &AU) const {
52+
AU.setPreservesAll();
53+
MachineFunctionPass::getAnalysisUsage(AU);
54+
}
55+
56+
static LLT updateType(LLT Ty, bool FP) {
57+
LLT InferredScalarTy =
58+
FP ? LLT::floatingPoint(Ty.getScalarSizeInBits(), LLT::FPInfo::IEEE_FLOAT)
59+
: LLT::integer(Ty.getScalarSizeInBits());
60+
LLT InferredTy =
61+
Ty.isVector() ? Ty.changeElementType(InferredScalarTy) : InferredScalarTy;
62+
63+
return InferredTy;
64+
}
65+
66+
void InferTypeInfo::updateDef(Register Reg) {
67+
LLT Ty = MRI->getType(Reg);
68+
LLT InferredTy = updateType(Ty, false);
69+
70+
MRI->setType(Reg, InferredTy);
71+
}
72+
73+
void InferTypeInfo::updateUse(MachineOperand &Op, bool FP) {
74+
assert(Op.isReg());
75+
LLT Ty = MRI->getType(Op.getReg());
76+
LLT InferredTy = updateType(Ty, FP);
77+
78+
MachineOperand *Def = MRI->getOneDef(Op.getReg());
79+
MachineInstr *MI = Op.getParent();
80+
MachineBasicBlock *MBB = MI->getParent();
81+
82+
Builder.setInsertPt(*MBB, MI);
83+
auto Bitcast = Builder.buildBitcast(InferredTy, Def->getReg());
84+
Op.setReg(Bitcast.getReg(0));
85+
}
86+
87+
constexpr unsigned MaxFPRSearchDepth = 5;
88+
89+
bool InferTypeInfo::shouldBeFP(MachineOperand &Op, unsigned Depth = 0) const {
90+
if (Depth > MaxFPRSearchDepth)
91+
return false;
92+
93+
if (!Op.isReg())
94+
return false;
95+
96+
MachineInstr &MI = *Op.getParent();
97+
98+
auto Pred = [&](MachineOperand &O) { return shouldBeFP(O, Depth + 1); };
99+
100+
// TODO: cache FP registers
101+
102+
switch (MI.getOpcode()) {
103+
// def and use fp instructions
104+
case TargetOpcode::G_FABS:
105+
case TargetOpcode::G_FADD:
106+
case TargetOpcode::G_FCANONICALIZE:
107+
case TargetOpcode::G_FCEIL:
108+
case TargetOpcode::G_FCONSTANT:
109+
case TargetOpcode::G_FCOPYSIGN:
110+
case TargetOpcode::G_FCOS:
111+
case TargetOpcode::G_FDIV:
112+
case TargetOpcode::G_FEXP2:
113+
case TargetOpcode::G_FEXP:
114+
case TargetOpcode::G_FFLOOR:
115+
case TargetOpcode::G_FLOG10:
116+
case TargetOpcode::G_FLOG2:
117+
case TargetOpcode::G_FLOG:
118+
case TargetOpcode::G_FMA:
119+
case TargetOpcode::G_FMAD:
120+
case TargetOpcode::G_FMAXIMUM:
121+
case TargetOpcode::G_FMAXNUM:
122+
case TargetOpcode::G_FMAXNUM_IEEE:
123+
case TargetOpcode::G_FMINIMUM:
124+
case TargetOpcode::G_FMINNUM:
125+
case TargetOpcode::G_FMINNUM_IEEE:
126+
case TargetOpcode::G_FMUL:
127+
case TargetOpcode::G_FNEARBYINT:
128+
case TargetOpcode::G_FNEG:
129+
case TargetOpcode::G_FPEXT:
130+
case TargetOpcode::G_FPOW:
131+
case TargetOpcode::G_FPTRUNC:
132+
case TargetOpcode::G_FREM:
133+
case TargetOpcode::G_FRINT:
134+
case TargetOpcode::G_FSIN:
135+
case TargetOpcode::G_FTAN:
136+
case TargetOpcode::G_FACOS:
137+
case TargetOpcode::G_FASIN:
138+
case TargetOpcode::G_FATAN:
139+
case TargetOpcode::G_FATAN2:
140+
case TargetOpcode::G_FCOSH:
141+
case TargetOpcode::G_FSINH:
142+
case TargetOpcode::G_FTANH:
143+
case TargetOpcode::G_FSQRT:
144+
case TargetOpcode::G_FSUB:
145+
case TargetOpcode::G_INTRINSIC_ROUND:
146+
case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
147+
case TargetOpcode::G_INTRINSIC_TRUNC:
148+
case TargetOpcode::G_VECREDUCE_FADD:
149+
case TargetOpcode::G_VECREDUCE_FMUL:
150+
case TargetOpcode::G_VECREDUCE_FMAX:
151+
case TargetOpcode::G_VECREDUCE_FMIN:
152+
case TargetOpcode::G_VECREDUCE_FMAXIMUM:
153+
case TargetOpcode::G_VECREDUCE_FMINIMUM:
154+
case TargetOpcode::G_VECREDUCE_SEQ_FADD:
155+
case TargetOpcode::G_VECREDUCE_SEQ_FMUL:
156+
return true;
157+
// use only fp instructions
158+
case TargetOpcode::G_SITOFP:
159+
case TargetOpcode::G_UITOFP:
160+
return Op.isDef();
161+
// def only fp instructions
162+
case TargetOpcode::G_FPTOSI:
163+
case TargetOpcode::G_FPTOUI:
164+
case TargetOpcode::G_FPTOSI_SAT:
165+
case TargetOpcode::G_FPTOUI_SAT:
166+
case TargetOpcode::G_FCMP:
167+
case TargetOpcode::G_LROUND:
168+
case TargetOpcode::G_LLROUND:
169+
return Op.isUse();
170+
case TargetOpcode::G_FREEZE:
171+
case TargetOpcode::G_IMPLICIT_DEF:
172+
case TargetOpcode::G_PHI:
173+
case TargetOpcode::G_SELECT:
174+
case TargetOpcode::G_BUILD_VECTOR:
175+
case TargetOpcode::G_CONCAT_VECTORS:
176+
case TargetOpcode::G_INSERT_SUBVECTOR:
177+
case TargetOpcode::G_EXTRACT_SUBVECTOR:
178+
case TargetOpcode::G_SHUFFLE_VECTOR:
179+
case TargetOpcode::G_SPLAT_VECTOR:
180+
case TargetOpcode::G_STEP_VECTOR:
181+
case TargetOpcode::G_VECTOR_COMPRESS: {
182+
return all_of(MI.all_defs(),
183+
[&](MachineOperand &O) {
184+
return all_of(MRI->use_operands(O.getReg()), Pred);
185+
}) &&
186+
all_of(MI.all_uses(), [&](MachineOperand &O) {
187+
return all_of(MRI->def_operands(O.getReg()), Pred);
188+
});
189+
}
190+
case TargetOpcode::G_INSERT_VECTOR_ELT:
191+
case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
192+
MachineOperand &Dst = MI.getOperand(0);
193+
MachineOperand &LHS = MI.getOperand(1);
194+
MachineOperand &RHS = MI.getOperand(2);
195+
196+
return all_of(MRI->use_operands(Dst.getReg()), Pred) &&
197+
(!LHS.isReg() || all_of(MRI->def_operands(LHS.getReg()), Pred)) &&
198+
(!RHS.isReg() || all_of(MRI->def_operands(RHS.getReg()), Pred));
199+
}
200+
case TargetOpcode::G_STORE:
201+
case TargetOpcode::G_INDEXED_STORE: {
202+
MachineOperand &Val = MI.getOperand(0);
203+
return Op.getReg() == Val.getReg() && all_of(MRI->def_operands(Op.getReg()), Pred);
204+
}
205+
case TargetOpcode::G_INDEXED_LOAD:
206+
case TargetOpcode::G_LOAD: {
207+
MachineOperand &Dst = MI.getOperand(0);
208+
return Op.getReg() == Dst.getReg() && all_of(MRI->use_operands(Dst.getReg()), Pred);
209+
}
210+
case TargetOpcode::G_ATOMICRMW_FADD:
211+
case TargetOpcode::G_ATOMICRMW_FSUB:
212+
case TargetOpcode::G_ATOMICRMW_FMAX:
213+
case TargetOpcode::G_ATOMICRMW_FMIN: {
214+
MachineOperand &WriteBack = MI.getOperand(0);
215+
MachineOperand &FPOp = MI.getOperand(2);
216+
return Op.getReg() == WriteBack.getReg() || Op.getReg() == FPOp.getReg();
217+
}
218+
case TargetOpcode::G_INTRINSIC_CONVERGENT:
219+
case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
220+
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
221+
case TargetOpcode::G_INTRINSIC: {
222+
GIntrinsic *Intrinsic = dyn_cast<GIntrinsic>(&MI);
223+
if (!Intrinsic)
224+
return false;
225+
226+
switch (Intrinsic->getIntrinsicID()) {
227+
case Intrinsic::amdgcn_rcp:
228+
case Intrinsic::amdgcn_log:
229+
case Intrinsic::amdgcn_exp2:
230+
case Intrinsic::amdgcn_rsq:
231+
case Intrinsic::amdgcn_sqrt:
232+
case Intrinsic::amdgcn_fdot2_f16_f16:
233+
case Intrinsic::amdgcn_mfma_f32_4x4x4f16:
234+
return true;
235+
default:
236+
return false;
237+
}
238+
return false;
239+
}
240+
default:
241+
break;
242+
}
243+
244+
return false;
245+
}
246+
247+
bool InferTypeInfo::inferTypeInfo(MachineFunction &MF) {
248+
bool Changed = false;
249+
250+
for (MachineBasicBlock &MBB : MF) {
251+
for (MachineInstr &MI : MBB.instrs()) {
252+
253+
for (auto &Def : MI.all_defs()) {
254+
if (shouldBeFP(Def)) {
255+
updateDef(Def.getReg());
256+
Changed |= true;
257+
}
258+
}
259+
260+
for (auto &Use : MI.all_uses()) {
261+
bool IsFPDef =
262+
MRI->getVRegDef(Use.getReg()) &&
263+
all_of(MRI->def_operands(Use.getReg()),
264+
[&](MachineOperand &Op) { return shouldBeFP(Op); });
265+
bool IsFPUse = shouldBeFP(Use);
266+
267+
if (IsFPUse && !IsFPDef) {
268+
updateUse(Use, true);
269+
Changed |= true;
270+
} else if (!IsFPUse && IsFPDef) {
271+
updateUse(Use, false);
272+
Changed |= true;
273+
}
274+
}
275+
276+
for (auto &MemOp: MI.memoperands()) {
277+
bool IsFP = any_of(MI.all_defs(), [&](MachineOperand &O){ return shouldBeFP(O); }) ||
278+
any_of(MI.all_uses(), [&](MachineOperand &O){ return shouldBeFP(O); });
279+
280+
if (!IsFP)
281+
continue;
282+
283+
LLT Ty = MemOp->getType();
284+
LLT NewTy = updateType(Ty, true);
285+
MemOp->setType(NewTy);
286+
}
287+
}
288+
}
289+
290+
return Changed;
291+
}
292+
293+
bool InferTypeInfo::runOnMachineFunction(MachineFunction &MF) {
294+
init(MF);
295+
bool Changed = false;
296+
Changed |= inferTypeInfo(MF);
297+
return Changed;
298+
}

llvm/utils/gn/secondary/llvm/lib/CodeGen/GlobalISel/BUILD.gn

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static_library("GlobalISel") {
2626
"GISelKnownBits.cpp",
2727
"GlobalISel.cpp",
2828
"IRTranslator.cpp",
29+
"InferTypeInfoPass.cpp",
2930
"InlineAsmLowering.cpp",
3031
"InstructionSelect.cpp",
3132
"InstructionSelector.cpp",

0 commit comments

Comments
 (0)