20
20
#include <stdio.h>
#include <string.h>
#include <functional>
#include <clBLAS.h>
#include "mutex.h"
#include "AutoGemmIncludes/AutoGemmKernelSelection.h"
#include "GemmSpecialCases.h"
25
26
@@ -126,22 +127,47 @@ static char *getKernelName(cl_kernel clKernel)
126
127
return kernelName;
127
128
}
128
129
130
+ typedef struct kernel_map_key_ {
131
+ cl_context context; // address of context
132
+ cl_device_id device; // address of device
133
+ const char *kernelSource; // address of kernel source
134
+ } kernel_map_key;
135
+
136
+ bool operator <(const kernel_map_key & l, const kernel_map_key & r) {
137
+ if (l.context < r.context ) {
138
+ return true ;
139
+ } else if (r.context < l.context ) {
140
+ return false ;
141
+ }
142
+ if (l.device < r.device ) {
143
+ return true ;
144
+ } else if (r.device < l.device ) {
145
+ return false ;
146
+ }
147
+ if (l.kernelSource < r.kernelSource ) {
148
+ return true ;
149
+ } else if (r.kernelSource < l.kernelSource ) {
150
+ return false ;
151
+ }
152
+ return false ;
153
+ }
154
+
155
+
129
156
/* *****************************************************************************
130
157
* Make Gemm Kernel
131
158
*****************************************************************************/
132
159
// FIXME: This function should be returning an error.
133
160
void makeGemmKernel (
134
- cl_kernel *clKernel,
161
+ cl_kernel *clKernel, // ignored as input; returns as output
135
162
cl_command_queue clQueue,
136
163
const char *kernelSource,
137
164
const char *sourceBuildOptions,
138
165
const unsigned char **kernelBinary,
139
166
size_t *kernelBinarySize,
140
167
const char *binaryBuildOptions)
141
168
{
142
- typedef std::map<std::string, cl_kernel> kernel_map_t ;
143
-
144
- #if defined( _WIN32 )
169
+ typedef std::map<kernel_map_key, cl_kernel> kernel_map_t ;
170
+ #if defined( _WIN32 )
145
171
__declspec ( thread ) static kernel_map_t *kernel_map = 0 ;
146
172
#else
147
173
__thread static kernel_map_t *kernel_map = 0 ;
@@ -159,33 +185,20 @@ void makeGemmKernel(
159
185
err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
160
186
CL_CHECK (err)
161
187
162
- std::stringstream ss;
163
- ss << clDevice << " _" << clContext;
164
- std::string prefix = ss.str ();
165
-
166
- if (*clKernel) {
167
- char *kernelName = getKernelName (*clKernel);
168
- // kernel has already been built, return
169
- #ifdef AUTOGEMM_PRINT_DEBUG
170
- printf (" makeGemmKernel: \" %s\" already built; returning.\n " , kernelName);
171
- #endif
172
-
173
- // Check if kernel exists for this device
174
- std::string key = prefix + " _" + kernelName;
175
- kernel_map_t ::iterator idx = kernel_map->find (key);
176
-
177
-
178
- // If kernel not found for this device, set to NULL
179
- if (idx == kernel_map->end ()) {
180
- *clKernel = NULL ;
181
- } else {
182
- *clKernel = idx->second ;
183
- }
184
-
185
- delete[] kernelName;
188
+ // is kernel already compiled?
189
+ kernel_map_key key;
190
+ key.kernelSource = kernelSource;
191
+ key.context = clContext;
192
+ key.device = clDevice;
193
+ kernel_map_t ::iterator idx = kernel_map->find (key);
194
+ if (idx == kernel_map->end ()) {
195
+ *clKernel = NULL ;
196
+ } else {
197
+ *clKernel = idx->second ;
198
+ return ;
186
199
}
187
200
188
- if (!*clKernel) {
201
+ if (true /* !*clKernel*/ ) { // since kernel wasn't found in map
189
202
// kernel has not been built, so build it (from binary, preferably)
190
203
cl_program clProgram;
191
204
cl_int clBinaryStatus;
@@ -257,17 +270,13 @@ void makeGemmKernel(
257
270
err = clReleaseProgram (clProgram);
258
271
CL_CHECK (err)
259
272
260
- char *kernelName = getKernelName (*clKernel);
261
-
262
273
#ifdef AUTOGEMM_PRINT_DEBUG
263
274
printf (" makeGemmKernel: \" %s\" now built; returning.\n " , kernelName);
264
275
#endif
265
276
266
- std::string key = prefix + " _ " + kernelName;
277
+ // put kernel in map
267
278
(*kernel_map)[key] = *clKernel;
268
- delete[] kernelName;
269
279
}
270
-
271
280
return ;
272
281
}
273
282
@@ -557,6 +566,11 @@ clblasGemm(
557
566
/* *****************************************************************************
558
567
* Build kernels
559
568
*****************************************************************************/
569
+
570
+ tileClKernel = NULL ;
571
+ rowClKernel = NULL ;
572
+ colClKernel = NULL ;
573
+ cornerClKernel = NULL ;
560
574
if (needTileKernel) makeGemmKernel ( tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
561
575
if (needRowKernel) makeGemmKernel ( rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
562
576
if (needColKernel) makeGemmKernel ( colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
0 commit comments