14
14
* limitations under the License.
15
15
* ************************************************************************/
16
16
17
+ #include < map>
18
+ #include < string>
19
+ #include < sstream>
17
20
#include < stdio.h>
18
21
#include < string.h>
19
22
#include < clBLAS.h>
@@ -86,11 +89,35 @@ bool isZero<DoubleComplex>( DoubleComplex value ) {
86
89
return CREAL (value) == 0 && CIMAG (value) == 0 ;
87
90
};
88
91
92
+ static char *getKernelName (cl_kernel clKernel)
93
+ {
94
+ cl_int err;
95
+ // get kernel name
96
+ size_t kernelNameLength;
97
+ err = clGetKernelInfo (
98
+ clKernel,
99
+ CL_KERNEL_FUNCTION_NAME,
100
+ sizeof (kernelNameLength),
101
+ NULL ,
102
+ &kernelNameLength);
103
+ CL_CHECK (err)
89
104
105
+ char *kernelName = new char [kernelNameLength];
106
+ err = clGetKernelInfo (
107
+ clKernel,
108
+ CL_KERNEL_FUNCTION_NAME,
109
+ kernelNameLength*sizeof (char ),
110
+ kernelName,
111
+ NULL );
112
+ CL_CHECK (err)
113
+
114
+ return kernelName;
115
+ }
90
116
91
117
/* *****************************************************************************
92
118
* Make Gemm Kernel
93
119
*****************************************************************************/
120
+ // FIXME: This function should be returning an error.
94
121
void makeGemmKernel (
95
122
cl_kernel *clKernel,
96
123
cl_command_queue clQueue,
@@ -100,39 +127,47 @@ void makeGemmKernel(
100
127
size_t *kernelBinarySize,
101
128
const char *binaryBuildOptions)
102
129
{
130
+ // TODO: This will need to be converted to thread local when making clBLAS thread safe
131
+ typedef std::map<std::string, cl_kernel> kernel_map_t ;
132
+ static kernel_map_t kernel_map;
133
+
134
+ cl_context clContext;
135
+ cl_device_id clDevice;
103
136
cl_int err;
137
+
138
+ err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_CONTEXT, sizeof (clContext), &clContext, NULL );
139
+ CL_CHECK (err)
140
+ err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
141
+ CL_CHECK (err)
142
+
143
+ std::stringstream ss;
144
+ ss << clDevice << " _" << clContext;
145
+ std::string prefix = ss.str ();
146
+
104
147
if (*clKernel) {
148
+ char *kernelName = getKernelName (*clKernel);
105
149
// kernel has already been built, return
106
150
#ifdef AUTOGEMM_PRINT_DEBUG
107
- // get kernel name
108
- size_t kernelNameLength;
109
- err = clGetKernelInfo (
110
- *clKernel,
111
- CL_KERNEL_FUNCTION_NAME,
112
- sizeof (kernelNameLength),
113
- NULL ,
114
- &kernelNameLength );
115
- CL_CHECK (err)
116
- char *kernelName = new char [kernelNameLength];
117
- err = clGetKernelInfo (
118
- *clKernel,
119
- CL_KERNEL_FUNCTION_NAME,
120
- kernelNameLength*sizeof (char ),
121
- kernelName,
122
- NULL );
123
- CL_CHECK (err)
124
151
printf (" makeGemmKernel: \" %s\" already built; returning.\n " , kernelName);
125
- delete[] kernelName;
126
152
#endif
127
- return ;
128
- } else {
153
+
154
+ // Check if kernel exists for this device
155
+ std::string key = prefix + " _" + kernelName;
156
+ kernel_map_t ::iterator idx = kernel_map.find (key);
157
+
158
+
159
+ // If kernel not found for this device, set to NULL
160
+ if (idx == kernel_map.end ()) {
161
+ *clKernel = NULL ;
162
+ } else {
163
+ *clKernel = idx->second ;
164
+ }
165
+
166
+ delete[] kernelName;
167
+ }
168
+
169
+ if (!*clKernel) {
129
170
// kernel has not been built, so build it (from binary, preferably)
130
- cl_context clContext;
131
- cl_device_id clDevice;
132
- err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_CONTEXT, sizeof (clContext), &clContext, NULL );
133
- CL_CHECK (err)
134
- err = clGetCommandQueueInfo ( clQueue, CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
135
- CL_CHECK (err)
136
171
cl_program clProgram;
137
172
cl_int clBinaryStatus;
138
173
if (*kernelBinary) {
@@ -151,6 +186,9 @@ void makeGemmKernel(
151
186
binaryBuildOptions, NULL , NULL );
152
187
CL_CHECK (err)
153
188
} else {
189
+ #ifdef AUTOGEMM_PRINT_DEBUG
190
+ printf (" makeGemmKernel: Creating program from source\n " , *kernelBinarySize);
191
+ #endif
154
192
clProgram = clCreateProgramWithSource (
155
193
clContext,
156
194
1 , &kernelSource,
@@ -178,6 +216,7 @@ void makeGemmKernel(
178
216
printf (" %s\n " , buildLog);
179
217
// printf("\n\nKernel String:\n\n");
180
218
// printf("%s\n", kernelSource);
219
+ // FIXME: The function should be exiting at this point
181
220
}
182
221
183
222
err = clCreateKernelsInProgram (
@@ -187,32 +226,21 @@ void makeGemmKernel(
187
226
CL_CHECK (err)
188
227
err = clReleaseProgram (clProgram);
189
228
CL_CHECK (err)
190
-
229
+
230
+ char *kernelName = getKernelName (*clKernel);
231
+
191
232
#ifdef AUTOGEMM_PRINT_DEBUG
192
- // get kernel name
193
- size_t kernelNameLength;
194
- err = clGetKernelInfo (
195
- *clKernel,
196
- CL_KERNEL_FUNCTION_NAME,
197
- sizeof (kernelNameLength),
198
- NULL ,
199
- &kernelNameLength );
200
- CL_CHECK (err)
201
- char *kernelName = new char [kernelNameLength];
202
- err = clGetKernelInfo (
203
- *clKernel,
204
- CL_KERNEL_FUNCTION_NAME,
205
- kernelNameLength*sizeof (char ),
206
- kernelName,
207
- NULL );
208
- CL_CHECK (err)
209
233
printf (" makeGemmKernel: \" %s\" now built; returning.\n " , kernelName);
210
- delete[] kernelName;
211
234
#endif
235
+
236
+ std::string key = prefix + " _" + kernelName;
237
+ kernel_map[key] = *clKernel;
238
+ delete[] kernelName;
212
239
}
240
+
241
+ return ;
213
242
}
214
243
215
-
216
244
/* *****************************************************************************
217
245
* Enqueue Gemm Kernel
218
246
*****************************************************************************/
@@ -266,7 +294,7 @@ template<> clblasTranspose correctTranspose<DoubleComplex>( clblasTranspose tran
266
294
* templated Gemm
267
295
*****************************************************************************/
268
296
template <typename Precision>
269
- clblasStatus
297
+ clblasStatus
270
298
clblasGemm (
271
299
clblasOrder order,
272
300
clblasTranspose transA,
@@ -308,7 +336,7 @@ clblasGemm(
308
336
M, N, offA, offB, lda, ldb, A, B );
309
337
310
338
311
-
339
+
312
340
/* *****************************************************************************
313
341
* Handle Special Cases
314
342
*
@@ -318,7 +346,7 @@ clblasGemm(
318
346
* and are mod32 but not mod96 or mod64
319
347
*
320
348
*****************************************************************************/
321
-
349
+
322
350
bool specialCaseHandled = false ;
323
351
324
352
clblasStatus SpecialCaseStatus = GemmSpecialCases<Precision>(order,
@@ -339,8 +367,8 @@ clblasGemm(
339
367
340
368
if (specialCaseHandled)
341
369
return SpecialCaseStatus;
342
-
343
-
370
+
371
+
344
372
/* *****************************************************************************
345
373
* Optimal num elements per thread
346
374
*****************************************************************************/
@@ -512,7 +540,7 @@ clblasGemm(
512
540
gemmKernelArgs[11 ] = &offA; gemmKernelArgSizes[11 ] = sizeof (cl_uint);
513
541
gemmKernelArgs[12 ] = &offB; gemmKernelArgSizes[12 ] = sizeof (cl_uint);
514
542
gemmKernelArgs[13 ] = &offC; gemmKernelArgSizes[13 ] = sizeof (cl_uint);
515
-
543
+
516
544
517
545
/* *****************************************************************************
518
546
* Enqueue Tile kernel
@@ -577,8 +605,8 @@ clblasGemm(
577
605
/* *****************************************************************************
578
606
* SGEMM API call
579
607
*****************************************************************************/
580
- extern " C"
581
- clblasStatus
608
+ extern " C"
609
+ clblasStatus
582
610
clblasSgemm (
583
611
clblasOrder order,
584
612
clblasTranspose transA,
@@ -615,7 +643,7 @@ clblasSgemm(
615
643
/* *****************************************************************************
616
644
* DGEMM API call
617
645
*****************************************************************************/
618
- extern " C"
646
+ extern " C"
619
647
clblasStatus
620
648
clblasDgemm ( clblasOrder order,
621
649
clblasTranspose transA,
@@ -652,7 +680,7 @@ clblasDgemm( clblasOrder order,
652
680
/* *****************************************************************************
653
681
* CGEMM API call
654
682
*****************************************************************************/
655
- extern " C"
683
+ extern " C"
656
684
clblasStatus
657
685
clblasCgemm (
658
686
clblasOrder order,
@@ -690,7 +718,7 @@ clblasCgemm(
690
718
/* *****************************************************************************
691
719
* ZGEMM API
692
720
*****************************************************************************/
693
- extern " C"
721
+ extern " C"
694
722
clblasStatus
695
723
clblasZgemm (
696
724
clblasOrder order,
0 commit comments