//===- llvm/CodeGen/GlobalISel/InferTypeInfoPass.cpp - InferTypeInfoPass -===//
3+ //
4+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+ // See https://llvm.org/LICENSE.txt for license information.
6+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+ //
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the InferTypeInfoPass class.
//===----------------------------------------------------------------------===//
12+
13+ #include " llvm/CodeGen/GlobalISel/InferTypeInfoPass.h"
14+ #include " llvm/ADT/STLExtras.h"
15+ #include " llvm/ADT/SmallSet.h"
16+ #include " llvm/Analysis/AliasAnalysis.h"
17+ #include " llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
18+ #include " llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19+ #include " llvm/CodeGen/GlobalISel/LoadStoreOpt.h"
20+ #include " llvm/CodeGen/GlobalISel/Utils.h"
21+ #include " llvm/CodeGen/MachineBasicBlock.h"
22+ #include " llvm/CodeGen/MachineFrameInfo.h"
23+ #include " llvm/CodeGen/MachineFunction.h"
24+ #include " llvm/CodeGen/MachineInstr.h"
25+ #include " llvm/CodeGen/MachineMemOperand.h"
26+ #include " llvm/CodeGen/MachineOperand.h"
27+ #include " llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
28+ #include " llvm/CodeGen/MachineRegisterInfo.h"
29+ #include " llvm/CodeGen/Register.h"
30+ #include " llvm/CodeGen/TargetLowering.h"
31+ #include " llvm/CodeGen/TargetOpcodes.h"
32+ #include " llvm/IR/IntrinsicInst.h"
33+ #include " llvm/IR/IntrinsicsAMDGPU.h"
34+ #include " llvm/InitializePasses.h"
35+
36+ #define DEBUG_TYPE " mir-infer-type-info"
37+
38+ using namespace llvm ;
39+
40+ char InferTypeInfo::ID = 0 ;
41+
42+ INITIALIZE_PASS_BEGIN (InferTypeInfo, DEBUG_TYPE, " TODO" , false , false )
43+ INITIALIZE_PASS_END(InferTypeInfo, DEBUG_TYPE, " TODO" , false , false )
44+
45+ void InferTypeInfo::init(MachineFunction &MF) {
46+ this ->MF = &MF;
47+ MRI = &MF.getRegInfo ();
48+ Builder.setMF (MF);
49+ }
50+
51+ void InferTypeInfo::getAnalysisUsage (AnalysisUsage &AU) const {
52+ AU.setPreservesAll ();
53+ MachineFunctionPass::getAnalysisUsage (AU);
54+ }
55+
56+ static LLT updateType (LLT Ty, bool FP) {
57+ LLT InferredScalarTy =
58+ FP ? LLT::floatingPoint (Ty.getScalarSizeInBits (), LLT::FPInfo::IEEE_FLOAT)
59+ : LLT::integer (Ty.getScalarSizeInBits ());
60+ LLT InferredTy =
61+ Ty.isVector () ? Ty.changeElementType (InferredScalarTy) : InferredScalarTy;
62+
63+ return InferredTy;
64+ }
65+
66+ void InferTypeInfo::updateDef (Register Reg) {
67+ LLT Ty = MRI->getType (Reg);
68+ LLT InferredTy = updateType (Ty, false );
69+
70+ MRI->setType (Reg, InferredTy);
71+ }
72+
73+ void InferTypeInfo::updateUse (MachineOperand &Op, bool FP) {
74+ assert (Op.isReg ());
75+ LLT Ty = MRI->getType (Op.getReg ());
76+ LLT InferredTy = updateType (Ty, FP);
77+
78+ MachineOperand *Def = MRI->getOneDef (Op.getReg ());
79+ MachineInstr *MI = Op.getParent ();
80+ MachineBasicBlock *MBB = MI->getParent ();
81+
82+ Builder.setInsertPt (*MBB, MI);
83+ auto Bitcast = Builder.buildBitcast (InferredTy, Def->getReg ());
84+ Op.setReg (Bitcast.getReg (0 ));
85+ }
86+
87+ constexpr unsigned MaxFPRSearchDepth = 5 ;
88+
89+ bool InferTypeInfo::shouldBeFP (MachineOperand &Op, unsigned Depth = 0 ) const {
90+ if (Depth > MaxFPRSearchDepth)
91+ return false ;
92+
93+ if (!Op.isReg ())
94+ return false ;
95+
96+ MachineInstr &MI = *Op.getParent ();
97+
98+ auto Pred = [&](MachineOperand &O) { return shouldBeFP (O, Depth + 1 ); };
99+
100+ // TODO: cache FP registers
101+
102+ switch (MI.getOpcode ()) {
103+ // def and use fp instructions
104+ case TargetOpcode::G_FABS:
105+ case TargetOpcode::G_FADD:
106+ case TargetOpcode::G_FCANONICALIZE:
107+ case TargetOpcode::G_FCEIL:
108+ case TargetOpcode::G_FCONSTANT:
109+ case TargetOpcode::G_FCOPYSIGN:
110+ case TargetOpcode::G_FCOS:
111+ case TargetOpcode::G_FDIV:
112+ case TargetOpcode::G_FEXP2:
113+ case TargetOpcode::G_FEXP:
114+ case TargetOpcode::G_FFLOOR:
115+ case TargetOpcode::G_FLOG10:
116+ case TargetOpcode::G_FLOG2:
117+ case TargetOpcode::G_FLOG:
118+ case TargetOpcode::G_FMA:
119+ case TargetOpcode::G_FMAD:
120+ case TargetOpcode::G_FMAXIMUM:
121+ case TargetOpcode::G_FMAXNUM:
122+ case TargetOpcode::G_FMAXNUM_IEEE:
123+ case TargetOpcode::G_FMINIMUM:
124+ case TargetOpcode::G_FMINNUM:
125+ case TargetOpcode::G_FMINNUM_IEEE:
126+ case TargetOpcode::G_FMUL:
127+ case TargetOpcode::G_FNEARBYINT:
128+ case TargetOpcode::G_FNEG:
129+ case TargetOpcode::G_FPEXT:
130+ case TargetOpcode::G_FPOW:
131+ case TargetOpcode::G_FPTRUNC:
132+ case TargetOpcode::G_FREM:
133+ case TargetOpcode::G_FRINT:
134+ case TargetOpcode::G_FSIN:
135+ case TargetOpcode::G_FTAN:
136+ case TargetOpcode::G_FACOS:
137+ case TargetOpcode::G_FASIN:
138+ case TargetOpcode::G_FATAN:
139+ case TargetOpcode::G_FATAN2:
140+ case TargetOpcode::G_FCOSH:
141+ case TargetOpcode::G_FSINH:
142+ case TargetOpcode::G_FTANH:
143+ case TargetOpcode::G_FSQRT:
144+ case TargetOpcode::G_FSUB:
145+ case TargetOpcode::G_INTRINSIC_ROUND:
146+ case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
147+ case TargetOpcode::G_INTRINSIC_TRUNC:
148+ case TargetOpcode::G_VECREDUCE_FADD:
149+ case TargetOpcode::G_VECREDUCE_FMUL:
150+ case TargetOpcode::G_VECREDUCE_FMAX:
151+ case TargetOpcode::G_VECREDUCE_FMIN:
152+ case TargetOpcode::G_VECREDUCE_FMAXIMUM:
153+ case TargetOpcode::G_VECREDUCE_FMINIMUM:
154+ case TargetOpcode::G_VECREDUCE_SEQ_FADD:
155+ case TargetOpcode::G_VECREDUCE_SEQ_FMUL:
156+ return true ;
157+ // use only fp instructions
158+ case TargetOpcode::G_SITOFP:
159+ case TargetOpcode::G_UITOFP:
160+ return Op.isDef ();
161+ // def only fp instructions
162+ case TargetOpcode::G_FPTOSI:
163+ case TargetOpcode::G_FPTOUI:
164+ case TargetOpcode::G_FPTOSI_SAT:
165+ case TargetOpcode::G_FPTOUI_SAT:
166+ case TargetOpcode::G_FCMP:
167+ case TargetOpcode::G_LROUND:
168+ case TargetOpcode::G_LLROUND:
169+ return Op.isUse ();
170+ case TargetOpcode::G_FREEZE:
171+ case TargetOpcode::G_IMPLICIT_DEF:
172+ case TargetOpcode::G_PHI:
173+ case TargetOpcode::G_SELECT:
174+ case TargetOpcode::G_BUILD_VECTOR:
175+ case TargetOpcode::G_CONCAT_VECTORS:
176+ case TargetOpcode::G_INSERT_SUBVECTOR:
177+ case TargetOpcode::G_EXTRACT_SUBVECTOR:
178+ case TargetOpcode::G_SHUFFLE_VECTOR:
179+ case TargetOpcode::G_SPLAT_VECTOR:
180+ case TargetOpcode::G_STEP_VECTOR:
181+ case TargetOpcode::G_VECTOR_COMPRESS: {
182+ return all_of (MI.all_defs (),
183+ [&](MachineOperand &O) {
184+ return all_of (MRI->use_operands (O.getReg ()), Pred);
185+ }) &&
186+ all_of (MI.all_uses (), [&](MachineOperand &O) {
187+ return all_of (MRI->def_operands (O.getReg ()), Pred);
188+ });
189+ }
190+ case TargetOpcode::G_INSERT_VECTOR_ELT:
191+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
192+ MachineOperand &Dst = MI.getOperand (0 );
193+ MachineOperand &LHS = MI.getOperand (1 );
194+ MachineOperand &RHS = MI.getOperand (2 );
195+
196+ return all_of (MRI->use_operands (Dst.getReg ()), Pred) &&
197+ (!LHS.isReg () || all_of (MRI->def_operands (LHS.getReg ()), Pred)) &&
198+ (!RHS.isReg () || all_of (MRI->def_operands (RHS.getReg ()), Pred));
199+ }
200+ case TargetOpcode::G_STORE:
201+ case TargetOpcode::G_INDEXED_STORE: {
202+ MachineOperand &Val = MI.getOperand (0 );
203+ return Op.getReg () == Val.getReg () && all_of (MRI->def_operands (Op.getReg ()), Pred);
204+ }
205+ case TargetOpcode::G_INDEXED_LOAD:
206+ case TargetOpcode::G_LOAD: {
207+ MachineOperand &Dst = MI.getOperand (0 );
208+ return Op.getReg () == Dst.getReg () && all_of (MRI->use_operands (Dst.getReg ()), Pred);
209+ }
210+ case TargetOpcode::G_ATOMICRMW_FADD:
211+ case TargetOpcode::G_ATOMICRMW_FSUB:
212+ case TargetOpcode::G_ATOMICRMW_FMAX:
213+ case TargetOpcode::G_ATOMICRMW_FMIN: {
214+ MachineOperand &WriteBack = MI.getOperand (0 );
215+ MachineOperand &FPOp = MI.getOperand (2 );
216+ return Op.getReg () == WriteBack.getReg () || Op.getReg () == FPOp.getReg ();
217+ }
218+ case TargetOpcode::G_INTRINSIC_CONVERGENT:
219+ case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
220+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
221+ case TargetOpcode::G_INTRINSIC: {
222+ GIntrinsic *Intrinsic = dyn_cast<GIntrinsic>(&MI);
223+ if (!Intrinsic)
224+ return false ;
225+
226+ switch (Intrinsic->getIntrinsicID ()) {
227+ case Intrinsic::amdgcn_rcp:
228+ case Intrinsic::amdgcn_log:
229+ case Intrinsic::amdgcn_exp2:
230+ case Intrinsic::amdgcn_rsq:
231+ case Intrinsic::amdgcn_sqrt:
232+ case Intrinsic::amdgcn_fdot2_f16_f16:
233+ case Intrinsic::amdgcn_mfma_f32_4x4x4f16:
234+ return true ;
235+ default :
236+ return false ;
237+ }
238+ return false ;
239+ }
240+ default :
241+ break ;
242+ }
243+
244+ return false ;
245+ }
246+
247+ bool InferTypeInfo::inferTypeInfo (MachineFunction &MF) {
248+ bool Changed = false ;
249+
250+ for (MachineBasicBlock &MBB : MF) {
251+ for (MachineInstr &MI : MBB.instrs ()) {
252+
253+ for (auto &Def : MI.all_defs ()) {
254+ if (shouldBeFP (Def)) {
255+ updateDef (Def.getReg ());
256+ Changed |= true ;
257+ }
258+ }
259+
260+ for (auto &Use : MI.all_uses ()) {
261+ bool IsFPDef =
262+ MRI->getVRegDef (Use.getReg ()) &&
263+ all_of (MRI->def_operands (Use.getReg ()),
264+ [&](MachineOperand &Op) { return shouldBeFP (Op); });
265+ bool IsFPUse = shouldBeFP (Use);
266+
267+ if (IsFPUse && !IsFPDef) {
268+ updateUse (Use, true );
269+ Changed |= true ;
270+ } else if (!IsFPUse && IsFPDef) {
271+ updateUse (Use, false );
272+ Changed |= true ;
273+ }
274+ }
275+
276+ for (auto &MemOp: MI.memoperands ()) {
277+ bool IsFP = any_of (MI.all_defs (), [&](MachineOperand &O){ return shouldBeFP (O); }) ||
278+ any_of (MI.all_uses (), [&](MachineOperand &O){ return shouldBeFP (O); });
279+
280+ if (!IsFP)
281+ continue ;
282+
283+ LLT Ty = MemOp->getType ();
284+ LLT NewTy = updateType (Ty, true );
285+ MemOp->setType (NewTy);
286+ }
287+ }
288+ }
289+
290+ return Changed;
291+ }
292+
293+ bool InferTypeInfo::runOnMachineFunction (MachineFunction &MF) {
294+ init (MF);
295+ bool Changed = false ;
296+ Changed |= inferTypeInfo (MF);
297+ return Changed;
298+ }
0 commit comments