Skip to content

Commit 15548cf

Browse files
author
Kent Knox
committed
Merge pull request #235 from guacamoleo/develop
proposed fix for gemm thread safety; using thread-local storage
2 parents 1c5ba46 + c590881 commit 15548cf

File tree

1 file changed

+83
-55
lines changed

1 file changed

+83
-55
lines changed

src/library/blas/xgemm.cc

Lines changed: 83 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@
2020
#include <stdio.h>
2121
#include <string.h>
2222
#include <clBLAS.h>
23+
#include "mutex.h"
2324
#include "AutoGemmIncludes/AutoGemmKernelSelection.h"
2425
#include "GemmSpecialCases.h"
2526

2627
#include <functor.h>
2728
// #include <functor_selector.h>
2829
#include "xgemm.h"
2930

31+
#ifdef _WIN32
32+
//#include <thread>
33+
#else
34+
#include <pthread.h>
35+
#endif
36+
3037
/******************************************************************************
3138
* Row major -> column major
3239
*****************************************************************************/
@@ -120,22 +127,54 @@ static char *getKernelName(cl_kernel clKernel)
120127
return kernelName;
121128
}
122129

130+
typedef struct kernel_map_key_ {
131+
cl_context context; // address of context
132+
cl_device_id device; // address of device
133+
const char *kernelSource; // address of kernel source
134+
} kernel_map_key;
135+
136+
bool operator<(const kernel_map_key & l, const kernel_map_key & r) {
137+
if (l.context < r.context) {
138+
return true;
139+
} else if (r.context < l.context) {
140+
return false;
141+
}
142+
if (l.device < r.device) {
143+
return true;
144+
} else if (r.device < l.device) {
145+
return false;
146+
}
147+
if (l.kernelSource < r.kernelSource) {
148+
return true;
149+
} else if (r.kernelSource < l.kernelSource) {
150+
return false;
151+
}
152+
return false;
153+
}
154+
155+
123156
/******************************************************************************
124157
* Make Gemm Kernel
125158
*****************************************************************************/
126159
//FIXME: This function should be returning an error.
127160
void makeGemmKernel(
128-
cl_kernel *clKernel,
161+
cl_kernel *clKernel, // ignored as input; returns as output only
129162
cl_command_queue clQueue,
130163
const char *kernelSource,
131164
const char *sourceBuildOptions,
132165
const unsigned char **kernelBinary,
133166
size_t *kernelBinarySize,
134167
const char *binaryBuildOptions)
135168
{
136-
//TODO: This will need to be converted to thread local when making clBLAS thread safe
137-
typedef std::map<std::string, cl_kernel> kernel_map_t;
138-
static kernel_map_t kernel_map;
169+
typedef std::map<kernel_map_key, cl_kernel> kernel_map_t;
170+
#if defined( _WIN32 )
171+
__declspec( thread ) static kernel_map_t *kernel_map = 0;
172+
#else
173+
__thread static kernel_map_t *kernel_map = 0;
174+
#endif
175+
if (!kernel_map) {
176+
kernel_map = new kernel_map_t();
177+
}
139178

140179
cl_context clContext;
141180
cl_device_id clDevice;
@@ -146,33 +185,20 @@ void makeGemmKernel(
146185
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL);
147186
CL_CHECK(err)
148187

149-
std::stringstream ss;
150-
ss << clDevice << "_" << clContext;
151-
std::string prefix = ss.str();
152-
153-
if (*clKernel) {
154-
char *kernelName = getKernelName(*clKernel);
155-
// kernel has already been built, return
156-
#ifdef AUTOGEMM_PRINT_DEBUG
157-
printf("makeGemmKernel: \"%s\" already built; returning.\n", kernelName);
158-
#endif
159-
160-
// Check if kernel exists for this device
161-
std::string key = prefix + "_" + kernelName;
162-
kernel_map_t::iterator idx = kernel_map.find(key);
163-
164-
165-
// If kernel not found for this device, set to NULL
166-
if (idx == kernel_map.end()) {
167-
*clKernel = NULL;
168-
} else {
169-
*clKernel = idx->second;
170-
}
171-
172-
delete[] kernelName;
188+
// is kernel already compiled?
189+
kernel_map_key key;
190+
key.kernelSource = kernelSource;
191+
key.context = clContext;
192+
key.device = clDevice;
193+
kernel_map_t::iterator idx = kernel_map->find(key);
194+
if (idx == kernel_map->end()) {
195+
*clKernel = NULL;
196+
} else {
197+
*clKernel = idx->second;
198+
return;
173199
}
174200

175-
if (!*clKernel) {
201+
if (true /*!*clKernel*/) { // since kernel wasn't found in map
176202
// kernel has not been built, so build it (from binary, preferably)
177203
cl_program clProgram;
178204
cl_int clBinaryStatus;
@@ -244,17 +270,13 @@ void makeGemmKernel(
244270
err = clReleaseProgram(clProgram);
245271
CL_CHECK(err)
246272

247-
char *kernelName = getKernelName(*clKernel);
248-
249273
#ifdef AUTOGEMM_PRINT_DEBUG
250274
printf("makeGemmKernel: \"%s\" now built; returning.\n", kernelName);
251275
#endif
252276

253-
std::string key = prefix + "_" + kernelName;
254-
kernel_map[key] = *clKernel;
255-
delete[] kernelName;
277+
//put kernel in map
278+
(*kernel_map)[key] = *clKernel;
256279
}
257-
258280
return;
259281
}
260282

@@ -439,10 +461,10 @@ clblasGemm(
439461
size_t *colKernelBinarySize = 0;
440462
size_t *cornerKernelBinarySize = 0;
441463
const char *binaryBuildOptions = NULL;
442-
cl_kernel *tileClKernel = NULL;
443-
cl_kernel *rowClKernel = NULL;
444-
cl_kernel *colClKernel = NULL;
445-
cl_kernel *cornerClKernel = NULL;
464+
cl_kernel *tileClKernelDummy = NULL; // no longer used; broke thread safety
465+
cl_kernel *rowClKernelDummy = NULL; // no longer used; broke thread safety
466+
cl_kernel *colClKernelDummy = NULL; // no longer used; broke thread safety
467+
cl_kernel *cornerClKernelDummy = NULL; // no longer used; broke thread safety
446468
unsigned int workGroupNumRows;
447469
unsigned int workGroupNumCols;
448470
unsigned int microTileNumRows;
@@ -467,10 +489,10 @@ clblasGemm(
467489
&colKernelBinarySize,
468490
&cornerKernelBinarySize,
469491
&binaryBuildOptions,
470-
&tileClKernel,
471-
&rowClKernel,
472-
&colClKernel,
473-
&cornerClKernel,
492+
&tileClKernelDummy,
493+
&rowClKernelDummy,
494+
&colClKernelDummy,
495+
&cornerClKernelDummy,
474496
&workGroupNumRows,
475497
&workGroupNumCols,
476498
&microTileNumRows,
@@ -508,10 +530,10 @@ clblasGemm(
508530
&colKernelBinarySize,
509531
&cornerKernelBinarySize,
510532
&binaryBuildOptions,
511-
&tileClKernel,
512-
&rowClKernel,
513-
&colClKernel,
514-
&cornerClKernel,
533+
&tileClKernelDummy,
534+
&rowClKernelDummy,
535+
&colClKernelDummy,
536+
&cornerClKernelDummy,
515537
&workGroupNumRows,
516538
&workGroupNumCols,
517539
&microTileNumRows,
@@ -544,10 +566,16 @@ clblasGemm(
544566
/******************************************************************************
545567
* Build kernels
546568
*****************************************************************************/
547-
if (needTileKernel) makeGemmKernel( tileClKernel, commandQueues[0], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
548-
if (needRowKernel) makeGemmKernel( rowClKernel, commandQueues[0], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
549-
if (needColKernel) makeGemmKernel( colClKernel, commandQueues[0], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
550-
if (needCornerKernel) makeGemmKernel(cornerClKernel, commandQueues[0], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
569+
570+
571+
cl_kernel tileClKernel = NULL;
572+
cl_kernel rowClKernel = NULL;
573+
cl_kernel colClKernel = NULL;
574+
cl_kernel cornerClKernel = NULL;
575+
if (needTileKernel) makeGemmKernel( &tileClKernel, commandQueues[0], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
576+
if (needRowKernel) makeGemmKernel( &rowClKernel, commandQueues[0], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
577+
if (needColKernel) makeGemmKernel( &colClKernel, commandQueues[0], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);
578+
if (needCornerKernel) makeGemmKernel(&cornerClKernel, commandQueues[0], cornerKernelSource, sourceBuildOptions, &cornerKernelBinary, cornerKernelBinarySize, binaryBuildOptions);
551579
const size_t localWorkSize[2] = { workGroupNumRows, workGroupNumCols };
552580
unsigned int numKernelsEnqueued = 0;
553581

@@ -576,7 +604,7 @@ clblasGemm(
576604
if (needTileKernel) {
577605
//printf("enqueueing tile kernel\n");
578606
size_t globalWorkSize[2] = {(M/macroTileNumRows)*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
579-
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], *tileClKernel,
607+
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], tileClKernel,
580608
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
581609
globalWorkSize, localWorkSize,
582610
numEventsInWaitList, eventWaitList,
@@ -591,7 +619,7 @@ clblasGemm(
591619
if (needRowKernel) {
592620
//printf("enqueueing row kernel\n");
593621
size_t globalWorkSize[2] = {1*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
594-
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], *rowClKernel,
622+
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], rowClKernel,
595623
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
596624
globalWorkSize, localWorkSize,
597625
numEventsInWaitList, eventWaitList,
@@ -606,7 +634,7 @@ clblasGemm(
606634
if (needColKernel) {
607635
//printf("enqueueing col kernel\n");
608636
size_t globalWorkSize[2] = { (M/macroTileNumRows)*workGroupNumRows, 1*workGroupNumCols };
609-
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], *colClKernel,
637+
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], colClKernel,
610638
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
611639
globalWorkSize, localWorkSize,
612640
numEventsInWaitList, eventWaitList,
@@ -621,7 +649,7 @@ clblasGemm(
621649
if (needCornerKernel) {
622650
//printf("enqueueing corner kernel\n");
623651
size_t globalWorkSize[2] = { 1*workGroupNumRows, 1*workGroupNumCols };
624-
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], *cornerClKernel,
652+
err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], cornerClKernel,
625653
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
626654
globalWorkSize, localWorkSize,
627655
numEventsInWaitList, eventWaitList,

0 commit comments

Comments
 (0)