@@ -179,6 +179,14 @@ def __init__(self):
179
179
self .incStr += "#include <CL/cl.h>\n "
180
180
self .incStr += "#endif\n "
181
181
self .incStr += "\n "
182
+ self .incStr += "#ifdef __cplusplus\n "
183
+ self .incStr += "extern \" C\" {\n "
184
+ self .incStr += "#endif\n "
185
+ self .incStr += " void initAutoGemmClKernels(void);\n " ;
186
+ self .incStr += "#ifdef __cplusplus\n "
187
+ self .incStr += "}\n " ;
188
+ self .incStr += "#endif\n "
189
+ self .incStr += "\n " ;
182
190
183
191
self .cppName = Common .getIncludePath () + "AutoGemmClKernels.cpp"
184
192
self .cppFile = open (self .cppName , "w" )
@@ -190,29 +198,50 @@ def __init__(self):
190
198
self .cppStr += "#endif\n "
191
199
self .cppStr += "\n "
192
200
201
+
202
+ self .initFunction = "" ;
203
+ self .initFunction += "extern \" C\" {\n " ;
204
+ self .initFunction += " void initAutoGemmClKernels(void);\n " ;
205
+ self .initFunction += "}\n " ;
206
+ self .initFunction += "\n " ;
207
+ self .initFunction += "void initAutoGemmClKernels(void) {\n " ;
208
+ self .defines = "" ;
209
+
193
210
def addKernel (self , kernel ):
194
- kernelName = kernel .getName ()
195
- self .incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
196
- self .cppStr += "cl_kernel %s_clKernel = NULL;\n " % kernelName
197
- kernelName = kernel .getRowName ()
198
- self .incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
199
- self .cppStr += "cl_kernel %s_clKernel = NULL;\n " % kernelName
200
- kernelName = kernel .getColName ()
201
- self .incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
202
- self .cppStr += "cl_kernel %s_clKernel = NULL;\n " % kernelName
203
- kernelName = kernel .getCornerName ()
204
- self .incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
205
- self .cppStr += "cl_kernel %s_clKernel = NULL;\n " % kernelName
211
+ kernelNames = [
212
+ kernel .getName (),
213
+ kernel .getRowName (),
214
+ kernel .getColName (),
215
+ kernel .getCornerName ()
216
+ ]
217
+ for kernelName in kernelNames :
218
+ self .incStr += "extern cl_kernel %s_clKernel;\n " % kernelName
219
+
220
+ self .defines += "cl_kernel %s_clKernel = NULL;\n " % kernelName
221
+
222
+ self .initFunction += " if(%s_clKernel != NULL) {\n " % kernelName
223
+ self .initFunction += " clReleaseKernel(%s_clKernel);\n " % kernelName
224
+ self .initFunction += " %s_clKernel = NULL;\n " % kernelName
225
+ self .initFunction += " }\n "
206
226
207
227
self .incFile .write ( self .incStr )
208
228
self .incStr = ""
209
- self .cppFile .write ( self .cppStr )
210
- self .cppStr = ""
229
+ # self.cppFile.write( self.cppStr )
230
+ # self.cppStr = ""
211
231
212
232
def writeToFile (self ):
213
233
self .incFile .write ( self .incStr )
214
234
self .incFile .write ( "\n #endif\n " )
215
235
self .incFile .close ()
236
+
237
+ self .initFunction += "}\n " ;
238
+ self .cppStr += self .defines + "\n " ;
239
+ self .defines = "" ;
240
+ self .cppStr += self .initFunction + "\n " ;
241
+ self .initFunction = "" ;
242
+
243
+ # self.cppStr += "\n";
244
+ # self.cppStr += "initAutoGemmClKernels();\n";
216
245
self .cppFile .write ( self .cppStr )
217
246
self .cppFile .close ()
218
247
0 commit comments