20
20
#include < stdio.h>
21
21
#include < string.h>
22
22
#include < clBLAS.h>
23
+ #include " mutex.h"
23
24
#include " AutoGemmIncludes/AutoGemmKernelSelection.h"
24
25
#include " GemmSpecialCases.h"
25
26
26
27
#include < functor.h>
27
28
// #include <functor_selector.h>
28
29
#include " xgemm.h"
29
30
31
+ #ifdef _WIN32
32
+ // #include <thread>
33
+ #else
34
+ #include < pthread.h>
35
+ #endif
36
+
30
37
/* *****************************************************************************
31
38
* Row major -> column major
32
39
*****************************************************************************/
@@ -120,22 +127,54 @@ static char *getKernelName(cl_kernel clKernel)
120
127
return kernelName;
121
128
}
122
129
130
+ typedef struct kernel_map_key_ {
131
+ cl_context context; // address of context
132
+ cl_device_id device; // address of device
133
+ const char *kernelSource; // address of kernel source
134
+ } kernel_map_key;
135
+
136
+ bool operator <(const kernel_map_key & l, const kernel_map_key & r) {
137
+ if (l.context < r.context ) {
138
+ return true ;
139
+ } else if (r.context < l.context ) {
140
+ return false ;
141
+ }
142
+ if (l.device < r.device ) {
143
+ return true ;
144
+ } else if (r.device < l.device ) {
145
+ return false ;
146
+ }
147
+ if (l.kernelSource < r.kernelSource ) {
148
+ return true ;
149
+ } else if (r.kernelSource < l.kernelSource ) {
150
+ return false ;
151
+ }
152
+ return false ;
153
+ }
154
+
155
+
123
156
/* *****************************************************************************
124
157
* Make Gemm Kernel
125
158
*****************************************************************************/
126
159
// FIXME: This function should be returning an error.
127
160
void makeGemmKernel (
128
- cl_kernel *clKernel,
161
+ cl_kernel *clKernel, // ignored as input; returns as output only
129
162
cl_command_queue clQueue,
130
163
const char *kernelSource,
131
164
const char *sourceBuildOptions,
132
165
const unsigned char **kernelBinary,
133
166
size_t *kernelBinarySize,
134
167
const char *binaryBuildOptions)
135
168
{
136
- // TODO: This will need to be converted to thread local when making clBLAS thread safe
137
- typedef std::map<std::string, cl_kernel> kernel_map_t ;
138
- static kernel_map_t kernel_map;
169
+ typedef std::map<kernel_map_key, cl_kernel> kernel_map_t ;
170
+ #if defined( _WIN32 )
171
+ __declspec ( thread ) static kernel_map_t *kernel_map = 0 ;
172
+ #else
173
+ __thread static kernel_map_t *kernel_map = 0 ;
174
+ #endif
175
+ if (!kernel_map) {
176
+ kernel_map = new kernel_map_t ();
177
+ }
139
178
140
179
cl_context clContext;
141
180
cl_device_id clDevice;
@@ -146,33 +185,20 @@ void makeGemmKernel(
146
185
err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
147
186
CL_CHECK (err)
148
187
149
- std::stringstream ss;
150
- ss << clDevice << " _" << clContext;
151
- std::string prefix = ss.str ();
152
-
153
- if (*clKernel) {
154
- char *kernelName = getKernelName (*clKernel);
155
- // kernel has already been built, return
156
- #ifdef AUTOGEMM_PRINT_DEBUG
157
- printf (" makeGemmKernel: \" %s\" already built; returning.\n " , kernelName);
158
- #endif
159
-
160
- // Check if kernel exists for this device
161
- std::string key = prefix + " _" + kernelName;
162
- kernel_map_t ::iterator idx = kernel_map.find (key);
163
-
164
-
165
- // If kernel not found for this device, set to NULL
166
- if (idx == kernel_map.end ()) {
167
- *clKernel = NULL ;
168
- } else {
169
- *clKernel = idx->second ;
170
- }
171
-
172
- delete[] kernelName;
188
+ // is kernel already compiled?
189
+ kernel_map_key key;
190
+ key.kernelSource = kernelSource;
191
+ key.context = clContext;
192
+ key.device = clDevice;
193
+ kernel_map_t ::iterator idx = kernel_map->find (key);
194
+ if (idx == kernel_map->end ()) {
195
+ *clKernel = NULL ;
196
+ } else {
197
+ *clKernel = idx->second ;
198
+ return ;
173
199
}
174
200
175
- if (!*clKernel) {
201
+ if (true /* !*clKernel*/ ) { // since kernel wasn't found in map
176
202
// kernel has not been built, so build it (from binary, preferably)
177
203
cl_program clProgram;
178
204
cl_int clBinaryStatus;
@@ -244,17 +270,13 @@ void makeGemmKernel(
244
270
err = clReleaseProgram (clProgram);
245
271
CL_CHECK (err)
246
272
247
- char *kernelName = getKernelName (*clKernel);
248
-
249
273
#ifdef AUTOGEMM_PRINT_DEBUG
250
274
printf (" makeGemmKernel: \" %s\" now built; returning.\n " , kernelName);
251
275
#endif
252
276
253
- std::string key = prefix + " _" + kernelName;
254
- kernel_map[key] = *clKernel;
255
- delete[] kernelName;
277
+ // put kernel in map
278
+ (*kernel_map)[key] = *clKernel;
256
279
}
257
-
258
280
return ;
259
281
}
260
282
@@ -439,10 +461,10 @@ clblasGemm(
439
461
size_t *colKernelBinarySize = 0 ;
440
462
size_t *cornerKernelBinarySize = 0 ;
441
463
const char *binaryBuildOptions = NULL ;
442
- cl_kernel *tileClKernel = NULL ;
443
- cl_kernel *rowClKernel = NULL ;
444
- cl_kernel *colClKernel = NULL ;
445
- cl_kernel *cornerClKernel = NULL ;
464
+ cl_kernel *tileClKernelDummy = NULL ; // no longer used; broke thread safety
465
+ cl_kernel *rowClKernelDummy = NULL ; // no longer used; broke thread safety
466
+ cl_kernel *colClKernelDummy = NULL ; // no longer used; broke thread safety
467
+ cl_kernel *cornerClKernelDummy = NULL ; // no longer used; broke thread safety
446
468
unsigned int workGroupNumRows;
447
469
unsigned int workGroupNumCols;
448
470
unsigned int microTileNumRows;
@@ -467,10 +489,10 @@ clblasGemm(
467
489
&colKernelBinarySize,
468
490
&cornerKernelBinarySize,
469
491
&binaryBuildOptions,
470
- &tileClKernel ,
471
- &rowClKernel ,
472
- &colClKernel ,
473
- &cornerClKernel ,
492
+ &tileClKernelDummy ,
493
+ &rowClKernelDummy ,
494
+ &colClKernelDummy ,
495
+ &cornerClKernelDummy ,
474
496
&workGroupNumRows,
475
497
&workGroupNumCols,
476
498
µTileNumRows,
@@ -508,10 +530,10 @@ clblasGemm(
508
530
&colKernelBinarySize,
509
531
&cornerKernelBinarySize,
510
532
&binaryBuildOptions,
511
- &tileClKernel ,
512
- &rowClKernel ,
513
- &colClKernel ,
514
- &cornerClKernel ,
533
+ &tileClKernelDummy ,
534
+ &rowClKernelDummy ,
535
+ &colClKernelDummy ,
536
+ &cornerClKernelDummy ,
515
537
&workGroupNumRows,
516
538
&workGroupNumCols,
517
539
µTileNumRows,
@@ -544,10 +566,16 @@ clblasGemm(
544
566
/* *****************************************************************************
545
567
* Build kernels
546
568
*****************************************************************************/
547
- if (needTileKernel) makeGemmKernel ( tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
548
- if (needRowKernel) makeGemmKernel ( rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
549
- if (needColKernel) makeGemmKernel ( colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
550
- if (needCornerKernel) makeGemmKernel (cornerClKernel, commandQueues[0 ], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
569
+
570
+
571
+ cl_kernel tileClKernel = NULL ;
572
+ cl_kernel rowClKernel = NULL ;
573
+ cl_kernel colClKernel = NULL ;
574
+ cl_kernel cornerClKernel = NULL ;
575
+ if (needTileKernel) makeGemmKernel ( &tileClKernel, commandQueues[0 ], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
576
+ if (needRowKernel) makeGemmKernel ( &rowClKernel, commandQueues[0 ], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
577
+ if (needColKernel) makeGemmKernel ( &colClKernel, commandQueues[0 ], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
578
+ if (needCornerKernel) makeGemmKernel (&cornerClKernel, commandQueues[0 ], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
551
579
const size_t localWorkSize[2 ] = { workGroupNumRows, workGroupNumCols };
552
580
unsigned int numKernelsEnqueued = 0 ;
553
581
@@ -576,7 +604,7 @@ clblasGemm(
576
604
if (needTileKernel) {
577
605
// printf("enqueueing tile kernel\n");
578
606
size_t globalWorkSize[2 ] = {(M/macroTileNumRows)*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
579
- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * tileClKernel,
607
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], tileClKernel,
580
608
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
581
609
globalWorkSize, localWorkSize,
582
610
numEventsInWaitList, eventWaitList,
@@ -591,7 +619,7 @@ clblasGemm(
591
619
if (needRowKernel) {
592
620
// printf("enqueueing row kernel\n");
593
621
size_t globalWorkSize[2 ] = {1 *workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
594
- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * rowClKernel,
622
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], rowClKernel,
595
623
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
596
624
globalWorkSize, localWorkSize,
597
625
numEventsInWaitList, eventWaitList,
@@ -606,7 +634,7 @@ clblasGemm(
606
634
if (needColKernel) {
607
635
// printf("enqueueing col kernel\n");
608
636
size_t globalWorkSize[2 ] = { (M/macroTileNumRows)*workGroupNumRows, 1 *workGroupNumCols };
609
- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * colClKernel,
637
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], colClKernel,
610
638
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
611
639
globalWorkSize, localWorkSize,
612
640
numEventsInWaitList, eventWaitList,
@@ -621,7 +649,7 @@ clblasGemm(
621
649
if (needCornerKernel) {
622
650
// printf("enqueueing corner kernel\n");
623
651
size_t globalWorkSize[2 ] = { 1 *workGroupNumRows, 1 *workGroupNumCols };
624
- err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], * cornerClKernel,
652
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], cornerClKernel,
625
653
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
626
654
globalWorkSize, localWorkSize,
627
655
numEventsInWaitList, eventWaitList,
0 commit comments