Skip to content

Commit 7a74778

Browse files
committed
compiling kernels is now thread safe; not using global cl_kernel objects
1 parent 02cf387 commit 7a74778

File tree

1 file changed

+48
-34
lines changed

1 file changed

+48
-34
lines changed

src/library/blas/xgemm.cc

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <stdio.h>
2121
#include <string.h>
2222
#include <clBLAS.h>
23+
#include "mutex.h"
2324
#include "AutoGemmIncludes/AutoGemmKernelSelection.h"
2425
#include "GemmSpecialCases.h"
2526

@@ -126,22 +127,47 @@ static char *getKernelName(cl_kernel clKernel)
126127
return kernelName;
127128
}
128129

130+
typedef struct kernel_map_key_ {
131+
cl_context context; // address of context
132+
cl_device_id device; // address of device
133+
const char *kernelSource; // address of kernel source
134+
} kernel_map_key;
135+
136+
bool operator<(const kernel_map_key & l, const kernel_map_key & r) {
137+
if (l.context < r.context) {
138+
return true;
139+
} else if (r.context < l.context) {
140+
return false;
141+
}
142+
if (l.device < r.device) {
143+
return true;
144+
} else if (r.device < l.device) {
145+
return false;
146+
}
147+
if (l.kernelSource < r.kernelSource) {
148+
return true;
149+
} else if (r.kernelSource < l.kernelSource) {
150+
return false;
151+
}
152+
return false;
153+
}
154+
155+
129156
/******************************************************************************
130157
* Make Gemm Kernel
131158
*****************************************************************************/
132159
//FIXME: This function should be returning an error.
133160
void makeGemmKernel(
134-
cl_kernel *clKernel,
161+
cl_kernel *clKernel, // ignored as input; returns as output
135162
cl_command_queue clQueue,
136163
const char *kernelSource,
137164
const char *sourceBuildOptions,
138165
const unsigned char **kernelBinary,
139166
size_t *kernelBinarySize,
140167
const char *binaryBuildOptions)
141168
{
142-
typedef std::map<std::string, cl_kernel> kernel_map_t;
143-
144-
#if defined( _WIN32 )
169+
typedef std::map<kernel_map_key, cl_kernel> kernel_map_t;
170+
#if defined( _WIN32 )
145171
__declspec( thread ) static kernel_map_t *kernel_map = 0;
146172
#else
147173
__thread static kernel_map_t *kernel_map = 0;
@@ -159,33 +185,20 @@ void makeGemmKernel(
159185
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL);
160186
CL_CHECK(err)
161187

162-
std::stringstream ss;
163-
ss << clDevice << "_" << clContext;
164-
std::string prefix = ss.str();
165-
166-
if (*clKernel) {
167-
char *kernelName = getKernelName(*clKernel);
168-
// kernel has already been built, return
169-
#ifdef AUTOGEMM_PRINT_DEBUG
170-
printf("makeGemmKernel: \"%s\" already built; returning.\n", kernelName);
171-
#endif
172-
173-
// Check if kernel exists for this device
174-
std::string key = prefix + "_" + kernelName;
175-
kernel_map_t::iterator idx = kernel_map->find(key);
176-
177-
178-
// If kernel not found for this device, set to NULL
179-
if (idx == kernel_map->end()) {
180-
*clKernel = NULL;
181-
} else {
182-
*clKernel = idx->second;
183-
}
184-
185-
delete[] kernelName;
188+
// is kernel already compiled?
189+
kernel_map_key key;
190+
key.kernelSource = kernelSource;
191+
key.context = clContext;
192+
key.device = clDevice;
193+
kernel_map_t::iterator idx = kernel_map->find(key);
194+
if (idx == kernel_map->end()) {
195+
*clKernel = NULL;
196+
} else {
197+
*clKernel = idx->second;
198+
return;
186199
}
187200

188-
if (!*clKernel) {
201+
if (true /*!*clKernel*/) { // since kernel wasn't found in map
189202
// kernel has not been built, so build it (from binary, preferably)
190203
cl_program clProgram;
191204
cl_int clBinaryStatus;
@@ -257,17 +270,13 @@ void makeGemmKernel(
257270
err = clReleaseProgram(clProgram);
258271
CL_CHECK(err)
259272

260-
char *kernelName = getKernelName(*clKernel);
261-
262273
#ifdef AUTOGEMM_PRINT_DEBUG
263274
printf("makeGemmKernel: \"%s\" now built; returning.\n", kernelName);
264275
#endif
265276

266-
std::string key = prefix + "_" + kernelName;
277+
//put kernel in map
267278
(*kernel_map)[key] = *clKernel;
268-
delete[] kernelName;
269279
}
270-
271280
return;
272281
}
273282

@@ -557,6 +566,11 @@ clblasGemm(
557566
/******************************************************************************
558567
* Build kernels
559568
*****************************************************************************/
569+
570+
tileClKernel = NULL;
571+
rowClKernel = NULL;
572+
colClKernel = NULL;
573+
cornerClKernel = NULL;
560574
if (needTileKernel) makeGemmKernel( tileClKernel, commandQueues[0], tileKernelSource, sourceBuildOptions, &tileKernelBinary, tileKernelBinarySize, binaryBuildOptions);
561575
if (needRowKernel) makeGemmKernel( rowClKernel, commandQueues[0], rowKernelSource, sourceBuildOptions, &rowKernelBinary, rowKernelBinarySize, binaryBuildOptions);
562576
if (needColKernel) makeGemmKernel( colClKernel, commandQueues[0], colKernelSource, sourceBuildOptions, &colKernelBinary, colKernelBinarySize, binaryBuildOptions);

0 commit comments

Comments
 (0)