@@ -64,6 +64,8 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
6464 llvm::Constant *getNullPointer (const CodeGen::CodeGenModule &CGM,
6565 llvm::PointerType *T,
6666 QualType QT) const override ;
67+ void setTargetAttributes (const Decl *D, llvm::GlobalValue *GV,
68+ CodeGen::CodeGenModule &M) const override ;
6769};
6870class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
6971public:
@@ -268,6 +270,22 @@ CommonSPIRTargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM,
268270 llvm::ConstantPointerNull::get (NPT), PT);
269271}
270272
273+ void CommonSPIRTargetCodeGenInfo::setTargetAttributes (
274+ const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
275+ if (M.getLangOpts ().OpenCL || GV->isDeclaration ())
276+ return ;
277+
278+ const FunctionDecl *FD = dyn_cast<FunctionDecl>(D);
279+ if (!FD)
280+ return ;
281+
282+ llvm::Function *F = dyn_cast<llvm::Function>(GV);
283+ assert (F && " Expected GlobalValue to be a Function" );
284+
285+ if (FD->hasAttr <DeviceKernelAttr>())
286+ F->setCallingConv (getDeviceKernelCallingConv ());
287+ }
288+
271289LangAS
272290SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace (CodeGenModule &CGM,
273291 const VarDecl *D) const {
@@ -292,19 +310,23 @@ SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM,
292310
293311void SPIRVTargetCodeGenInfo::setTargetAttributes (
294312 const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
295- if (!M.getLangOpts ().HIP ||
296- M.getTarget ().getTriple ().getVendor () != llvm::Triple::AMD)
297- return ;
298313 if (GV->isDeclaration ())
299314 return ;
300315
301- auto F = dyn_cast<llvm::Function>(GV );
302- if (!F )
316+ const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D );
317+ if (!FD )
303318 return ;
304319
305- auto FD = dyn_cast_or_null<FunctionDecl>(D);
306- if (!FD)
320+ llvm::Function *F = dyn_cast<llvm::Function>(GV);
321+ assert (F && " Expected GlobalValue to be a Function" );
322+
323+ if (FD->hasAttr <DeviceKernelAttr>())
324+ F->setCallingConv (getDeviceKernelCallingConv ());
325+
326+ if (!M.getLangOpts ().HIP ||
327+ M.getTarget ().getTriple ().getVendor () != llvm::Triple::AMD)
307328 return ;
329+
308330 if (!FD->hasAttr <CUDAGlobalAttr>())
309331 return ;
310332
0 commit comments