1515#include " llvm/CodeGen/ReplaceWithVeclib.h"
1616#include " llvm/ADT/STLExtras.h"
1717#include " llvm/ADT/Statistic.h"
18+ #include " llvm/ADT/StringRef.h"
1819#include " llvm/Analysis/DemandedBits.h"
1920#include " llvm/Analysis/GlobalsModRef.h"
2021#include " llvm/Analysis/OptimizationRemarkEmitter.h"
2122#include " llvm/Analysis/TargetLibraryInfo.h"
2223#include " llvm/Analysis/VectorUtils.h"
2324#include " llvm/CodeGen/Passes.h"
25+ #include " llvm/IR/DerivedTypes.h"
2426#include " llvm/IR/IRBuilder.h"
2527#include " llvm/IR/InstIterator.h"
28+ #include " llvm/Support/TypeSize.h"
2629#include " llvm/Transforms/Utils/ModuleUtils.h"
2730
2831using namespace llvm ;
@@ -38,138 +41,137 @@ STATISTIC(NumTLIFuncDeclAdded,
3841STATISTIC (NumFuncUsedAdded,
3942 " Number of functions added to `llvm.compiler.used`" );
4043
41- static bool replaceWithTLIFunction (CallInst &CI, const StringRef TLIName) {
42- Module *M = CI.getModule ();
43-
44- Function *OldFunc = CI.getCalledFunction ();
45-
46- // Check if the vector library function is already declared in this module,
47- // otherwise insert it.
44+ // / Returns a vector Function that it adds to the Module \p M. When \p
45+ // / ScalarFunc is not null, it copies its attributes to the newly created
46+ // / Function.
47+ Function *getTLIFunction (Module *M, FunctionType *VectorFTy,
48+ const StringRef TLIName,
49+ Function *ScalarFunc = nullptr ) {
4850 Function *TLIFunc = M->getFunction (TLIName);
4951 if (!TLIFunc) {
50- TLIFunc = Function::Create (OldFunc->getFunctionType (),
51- Function::ExternalLinkage, TLIName, *M);
52- TLIFunc->copyAttributesFrom (OldFunc);
52+ TLIFunc =
53+ Function::Create (VectorFTy, Function::ExternalLinkage, TLIName, *M);
54+ if (ScalarFunc)
55+ TLIFunc->copyAttributesFrom (ScalarFunc);
5356
5457 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Added vector library function `"
5558 << TLIName << " ` of type `" << *(TLIFunc->getType ())
5659 << " ` to module.\n " );
5760
5861 ++NumTLIFuncDeclAdded;
59-
60- // Add the freshly created function to llvm.compiler.used,
61- // similar to as it is done in InjectTLIMappings
62+ // Add the freshly created function to llvm.compiler.used, similar to how it
63+ // is done in InjectTLIMappings.
6264 appendToCompilerUsed (*M, {TLIFunc});
63-
6465 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Adding `" << TLIName
6566 << " ` to `@llvm.compiler.used`.\n " );
6667 ++NumFuncUsedAdded;
6768 }
69+ return TLIFunc;
70+ }
6871
69- // Replace the call to the vector intrinsic with a call
70- // to the corresponding function from the vector library.
71- IRBuilder<> IRBuilder (&CI);
72- SmallVector<Value *> Args (CI.args ());
73- // Preserve the operand bundles.
74- SmallVector<OperandBundleDef, 1 > OpBundles;
75- CI.getOperandBundlesAsDefs (OpBundles);
76- CallInst *Replacement = IRBuilder.CreateCall (TLIFunc, Args, OpBundles);
77- assert (OldFunc->getFunctionType () == TLIFunc->getFunctionType () &&
78- " Expecting function types to be identical" );
79- CI.replaceAllUsesWith (Replacement);
80- if (isa<FPMathOperator>(Replacement)) {
81- // Preserve fast math flags for FP math.
82- Replacement->copyFastMathFlags (&CI);
72+ // / Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
73+ // / the corresponding function from the vector library ( \p TLIVecFunc ).
74+ static void replaceWithTLIFunction (CallInst &CalltoReplace, VFInfo &Info,
75+ Function *TLIVecFunc) {
76+ IRBuilder<> IRBuilder (&CalltoReplace);
77+ SmallVector<Value *> Args (CalltoReplace.args ());
78+ if (auto OptMaskpos = Info.getParamIndexForOptionalMask ()) {
79+ auto *MaskTy = VectorType::get (Type::getInt1Ty (CalltoReplace.getContext ()),
80+ Info.Shape .VF );
81+ Args.insert (Args.begin () + OptMaskpos.value (),
82+ Constant::getAllOnesValue (MaskTy));
8383 }
8484
85- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
86- << OldFunc->getName () << " ` with call to `" << TLIName
87- << " `.\n " );
88- ++NumCallsReplaced;
89- return true ;
85+ // Preserve the operand bundles.
86+ SmallVector<OperandBundleDef, 1 > OpBundles;
87+ CalltoReplace.getOperandBundlesAsDefs (OpBundles);
88+ CallInst *Replacement = IRBuilder.CreateCall (TLIVecFunc, Args, OpBundles);
89+ CalltoReplace.replaceAllUsesWith (Replacement);
90+ // Preserve fast math flags for FP math.
91+ if (isa<FPMathOperator>(Replacement))
92+ Replacement->copyFastMathFlags (&CalltoReplace);
9093}
9194
95+ // / Returns true when \p CallToReplace was successfully replaced with a
96+ // / suitable function taking vector arguments, based on available mappings in
97+ // / the \p TLI. Currently only works when \p CallToReplace is a call to a
98+ // / vectorized intrinsic.
9299static bool replaceWithCallToVeclib (const TargetLibraryInfo &TLI,
93- CallInst &CI ) {
94- if (!CI .getCalledFunction ()) {
100+ CallInst &CallToReplace ) {
101+ if (!CallToReplace .getCalledFunction ())
95102 return false ;
96- }
97103
98- auto IntrinsicID = CI .getCalledFunction ()->getIntrinsicID ();
99- if (IntrinsicID == Intrinsic::not_intrinsic) {
100- // Replacement is only performed for intrinsic functions
104+ auto IntrinsicID = CallToReplace .getCalledFunction ()->getIntrinsicID ();
105+ // Replacement is only performed for intrinsic functions.
106+ if (IntrinsicID == Intrinsic::not_intrinsic)
101107 return false ;
102- }
103108
104- // Convert vector arguments to scalar type and check that
105- // all vector operands have identical vector width .
109+ // Compute the argument types of the corresponding scalar call. Additionally,
110+ // check that all vector operands of the vector call share one element count.
106111 ElementCount VF = ElementCount::getFixed (0 );
107- SmallVector<Type *> ScalarTypes;
108- for (auto Arg : enumerate(CI.args ())) {
109- auto *ArgType = Arg.value ()->getType ();
110- // Vector calls to intrinsics can still have
111- // scalar operands for specific arguments.
112+ SmallVector<Type *> ScalarArgTypes;
113+ for (auto Arg : enumerate(CallToReplace.args ())) {
114+ auto *ArgTy = Arg.value ()->getType ();
112115 if (isVectorIntrinsicWithScalarOpAtArg (IntrinsicID, Arg.index ())) {
113- ScalarTypes.push_back (ArgType);
114- } else {
115- // The argument in this place should be a vector if
116- // this is a call to a vector intrinsic.
117- auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
118- if (!VectorArgTy) {
119- // The argument is not a vector, do not perform
120- // the replacement.
121- return false ;
122- }
123- ElementCount NumElements = VectorArgTy->getElementCount ();
124- if (NumElements.isScalable ()) {
125- // The current implementation does not support
126- // scalable vectors.
116+ ScalarArgTypes.push_back (ArgTy);
117+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
118+ ScalarArgTypes.push_back (ArgTy->getScalarType ());
119+ // Disallow vector arguments with different VFs. When processing the first
120+ // vector argument, store its VF, and for the rest ensure that they match
121+ // it.
122+ if (VF.isZero ())
123+ VF = VectorArgTy->getElementCount ();
124+ else if (VF != VectorArgTy->getElementCount ())
127125 return false ;
128- }
129- if (VF.isNonZero () && VF != NumElements) {
130- // The different arguments differ in vector size.
131- return false ;
132- } else {
133- VF = NumElements;
134- }
135- ScalarTypes.push_back (VectorArgTy->getElementType ());
136- }
126+ } else
127+ // Exit when it is supposed to be a vector argument but it isn't.
128+ return false ;
137129 }
138130
139- // Try to reconstruct the name for the scalar version of this
140- // intrinsic using the intrinsic ID and the argument types
141- // converted to scalar above.
142- std::string ScalarName;
143- if (Intrinsic::isOverloaded (IntrinsicID)) {
144- ScalarName = Intrinsic::getName (IntrinsicID, ScalarTypes, CI.getModule ());
145- } else {
146- ScalarName = Intrinsic::getName (IntrinsicID).str ();
147- }
131+ // Try to reconstruct the name for the scalar version of this intrinsic using
132+ // the intrinsic ID and the argument types converted to scalar above.
133+ std::string ScalarName =
134+ (Intrinsic::isOverloaded (IntrinsicID)
135+ ? Intrinsic::getName (IntrinsicID, ScalarArgTypes,
136+ CallToReplace.getModule ())
137+ : Intrinsic::getName (IntrinsicID).str ());
138+
139+ // Try to find the mapping for the scalar version of this intrinsic and the
140+ // exact vector width of the call operands in the TargetLibraryInfo. First,
141+ // check with a non-masked variant, and if that fails try with a masked one.
142+ const VecDesc *VD =
143+ TLI.getVectorMappingInfo (ScalarName, VF, /* Masked*/ false );
144+ if (!VD && !(VD = TLI.getVectorMappingInfo (ScalarName, VF, /* Masked*/ true )))
145+ return false ;
148146
149- if (!TLI.isFunctionVectorizable (ScalarName)) {
150- // The TargetLibraryInfo does not contain a vectorized version of
151- // the scalar function.
147+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI mapping from: `" << ScalarName
148+ << " ` and vector width " << VF << " to: `"
149+ << VD->getVectorFnName () << " `.\n " );
150+
151+ // Replace the call to the intrinsic with a call to the vector library
152+ // function.
153+ Type *ScalarRetTy = CallToReplace.getType ()->getScalarType ();
154+ FunctionType *ScalarFTy =
155+ FunctionType::get (ScalarRetTy, ScalarArgTypes, /* isVarArg*/ false );
156+ const std::string MangledName = VD->getVectorFunctionABIVariantString ();
157+ auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
158+ if (!OptInfo)
152159 return false ;
153- }
154160
155- // Try to find the mapping for the scalar version of this intrinsic
156- // and the exact vector width of the call operands in the
157- // TargetLibraryInfo.
158- StringRef TLIName = TLI.getVectorizedFunction (ScalarName, VF);
159-
160- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Looking up TLI mapping for `"
161- << ScalarName << " ` and vector width " << VF << " .\n " );
162-
163- if (!TLIName.empty ()) {
164- // Found the correct mapping in the TargetLibraryInfo,
165- // replace the call to the intrinsic with a call to
166- // the vector library function.
167- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI function `" << TLIName
168- << " `.\n " );
169- return replaceWithTLIFunction (CI, TLIName);
170- }
161+ FunctionType *VectorFTy = VFABI::createFunctionType (*OptInfo, ScalarFTy);
162+ if (!VectorFTy)
163+ return false ;
164+
165+ Function *FuncToReplace = CallToReplace.getCalledFunction ();
166+ Function *TLIFunc = getTLIFunction (CallToReplace.getModule (), VectorFTy,
167+ VD->getVectorFnName (), FuncToReplace);
168+ replaceWithTLIFunction (CallToReplace, *OptInfo, TLIFunc);
171169
172- return false ;
170+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
171+ << FuncToReplace->getName () << " ` with call to `"
172+ << TLIFunc->getName () << " `.\n " );
173+ ++NumCallsReplaced;
174+ return true ;
173175}
174176
175177static bool runImpl (const TargetLibraryInfo &TLI, Function &F) {
@@ -185,9 +187,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
185187 }
186188 // Erase the calls to the intrinsics that have been replaced
187189 // with calls to the vector library.
188- for (auto *CI : ReplacedCalls) {
190+ for (auto *CI : ReplacedCalls)
189191 CI->eraseFromParent ();
190- }
191192 return Changed;
192193}
193194
@@ -207,10 +208,10 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
207208 PA.preserve <DemandedBitsAnalysis>();
208209 PA.preserve <OptimizationRemarkEmitterAnalysis>();
209210 return PA;
210- } else {
211- // The pass did not replace any calls, hence it preserves all analyses.
212- return PreservedAnalyses::all ();
213211 }
212+
213+ // The pass did not replace any calls, hence it preserves all analyses.
214+ return PreservedAnalyses::all ();
214215}
215216
216217// //////////////////////////////////////////////////////////////////////////////
0 commit comments