@@ -22,18 +22,18 @@ SPDX-License-Identifier: MIT
2222using namespace llvm ;
2323
2424namespace IGC {
25- bool isTargetExtTy (llvm::Type *Ty) {
25+ bool isTargetExtTy (const llvm::Type *Ty) {
2626#if LLVM_VERSION_MAJOR >= 16
2727 return Ty->isTargetExtTy ();
2828#endif
2929 return false ;
3030}
3131
32- bool isImageBuiltinType (llvm::Type *BuiltinTy) {
32+ bool isImageBuiltinType (const llvm::Type *BuiltinTy) {
3333 if (BuiltinTy->isPointerTy () && !IGCLLVM::isOpaquePointerTy (BuiltinTy))
3434 BuiltinTy = IGCLLVM::getNonOpaquePtrEltTy (BuiltinTy);
3535
36- if (StructType *StructTy = dyn_cast<StructType>(BuiltinTy);
36+ if (const StructType *StructTy = dyn_cast<StructType>(BuiltinTy);
3737 StructTy && StructTy->isOpaque ()) {
3838 StringRef BuiltinName = StructTy->getName ();
3939 llvm::SmallVector<llvm::StringRef, 3 > Buffer;
@@ -51,7 +51,7 @@ bool isImageBuiltinType(llvm::Type *BuiltinTy) {
5151 return true ;
5252 }
5353#if LLVM_VERSION_MAJOR >= 16
54- else if (TargetExtType *ExtTy = dyn_cast<TargetExtType>(BuiltinTy);
54+ else if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(BuiltinTy);
5555 ExtTy && (ExtTy->getName () == " spirv.Image" ||
5656 ExtTy->getName () == " spirv.SampledImage" )) {
5757 return true ;
@@ -61,14 +61,27 @@ bool isImageBuiltinType(llvm::Type *BuiltinTy) {
6161 return false ;
6262}
6363
64- static bool isAnyArgTargetExtTy (const llvm::Function &F) {
65- for (const llvm::Argument &A : F.args ())
66- if (IGC::isTargetExtTy (A.getType ()))
64+ #if LLVM_VERSION_MAJOR >= 16
65+ static bool isNonOpenCLBuiltinType (const llvm::Type *Ty) {
66+ const llvm::TargetExtType *TET = dyn_cast<llvm::TargetExtType>(Ty);
67+ if (!TET)
68+ return false ;
69+
70+ StringRef Name = TET->getTargetExtName ();
71+ return Name.starts_with (" spirv.CooperativeMatrixKHR" ) ||
72+ Name.starts_with (" spirv.JointMatrixINTEL" );
73+ }
74+
75+ static bool isAnyArgOpenCLTargetExtTy (const llvm::Function &F) {
76+ for (const llvm::Argument &A : F.args ()) {
77+ const Type *ArgTy = A.getType ();
78+ if (isTargetExtTy (ArgTy) && !isNonOpenCLBuiltinType (ArgTy))
6779 return true ;
80+ }
81+
6882 return false ;
6983}
7084
71- #if LLVM_VERSION_MAJOR >= 16
7285static Function *
7386cloneFunctionWithPtrArgsInsteadTargetExtTy (Function &F, StringRef NameSuffix) {
7487 Module &M = *F.getParent ();
@@ -78,7 +91,7 @@ cloneFunctionWithPtrArgsInsteadTargetExtTy(Function &F, StringRef NameSuffix) {
7891 ParamTys.reserve (F.arg_size ());
7992 for (Argument &Arg : F.args ()) {
8093 Type *T = Arg.getType ();
81- if (IGC:: isTargetExtTy (T)) {
94+ if (isTargetExtTy (T) && ! isNonOpenCLBuiltinType (T)) {
8295 unsigned AS = 0 ;
8396 auto *TargetExtTy = cast<llvm::TargetExtType>(T);
8497 StringRef TyName = TargetExtTy->getName ();
@@ -162,7 +175,7 @@ static void replaceFunctionAtCallsites(Function &OldF, Function &NewF) {
162175 }
163176}
164177
165- void retypeTargetExtTyArgs (Module *M) {
178+ void retypeOpenCLTargetExtTyArgs (Module *M) {
166179 constexpr StringLiteral TempSuffix = " .__retype_tmp" ;
167180 SmallVector<Function *, 8 > RetypedFuncs;
168181
@@ -171,7 +184,7 @@ void retypeTargetExtTyArgs(Module *M) {
171184 // if (F.isDeclaration())
172185 // continue;
173186
174- if (!isAnyArgTargetExtTy (F))
187+ if (!isAnyArgOpenCLTargetExtTy (F))
175188 continue ;
176189
177190 if (Function *NewF =
0 commit comments