Skip to content

Commit 843bff5

Browse files
committed
Add caching mechanism based on context and device for gemm and trsm
1 parent c2e7334 commit 843bff5

File tree

2 files changed

+336
-270
lines changed

2 files changed

+336
-270
lines changed

src/library/blas/xgemm.cc

Lines changed: 85 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
* limitations under the License.
1515
* ************************************************************************/
1616

17+
#include <map>
18+
#include <string>
19+
#include <sstream>
1720
#include <stdio.h>
1821
#include <string.h>
1922
#include <clBLAS.h>
@@ -86,11 +89,35 @@ bool isZero<DoubleComplex>( DoubleComplex value ) {
8689
return CREAL(value) == 0 && CIMAG(value) == 0;
8790
};
8891

92+
static char *getKernelName(cl_kernel clKernel)
93+
{
94+
cl_int err;
95+
// get kernel name
96+
size_t kernelNameLength;
97+
err = clGetKernelInfo(
98+
clKernel,
99+
CL_KERNEL_FUNCTION_NAME,
100+
sizeof(kernelNameLength),
101+
NULL,
102+
&kernelNameLength);
103+
CL_CHECK(err)
89104

105+
char *kernelName = new char[kernelNameLength];
106+
err = clGetKernelInfo(
107+
clKernel,
108+
CL_KERNEL_FUNCTION_NAME,
109+
kernelNameLength*sizeof(char),
110+
kernelName,
111+
NULL );
112+
CL_CHECK(err)
113+
114+
return kernelName;
115+
}
90116

91117
/******************************************************************************
92118
* Make Gemm Kernel
93119
*****************************************************************************/
120+
//FIXME: This function should be returning an error.
94121
void makeGemmKernel(
95122
cl_kernel *clKernel,
96123
cl_command_queue clQueue,
@@ -100,39 +127,47 @@ void makeGemmKernel(
100127
size_t *kernelBinarySize,
101128
const char *binaryBuildOptions)
102129
{
130+
//TODO: This will need to be converted to thread local when making clBLAS thread safe
131+
typedef std::map<std::string, cl_kernel> kernel_map_t;
132+
static kernel_map_t kernel_map;
133+
134+
cl_context clContext;
135+
cl_device_id clDevice;
103136
cl_int err;
137+
138+
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_CONTEXT, sizeof(clContext), &clContext, NULL);
139+
CL_CHECK(err)
140+
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL);
141+
CL_CHECK(err)
142+
143+
std::stringstream ss;
144+
ss << clDevice << "_" << clContext;
145+
std::string prefix = ss.str();
146+
104147
if (*clKernel) {
148+
char *kernelName = getKernelName(*clKernel);
105149
// kernel has already been built, return
106150
#ifdef AUTOGEMM_PRINT_DEBUG
107-
// get kernel name
108-
size_t kernelNameLength;
109-
err = clGetKernelInfo(
110-
*clKernel,
111-
CL_KERNEL_FUNCTION_NAME,
112-
sizeof(kernelNameLength),
113-
NULL,
114-
&kernelNameLength );
115-
CL_CHECK(err)
116-
char *kernelName = new char[kernelNameLength];
117-
err = clGetKernelInfo(
118-
*clKernel,
119-
CL_KERNEL_FUNCTION_NAME,
120-
kernelNameLength*sizeof(char),
121-
kernelName,
122-
NULL );
123-
CL_CHECK(err)
124151
printf("makeGemmKernel: \"%s\" already built; returning.\n", kernelName);
125-
delete[] kernelName;
126152
#endif
127-
return;
128-
} else {
153+
154+
// Check if kernel exists for this device
155+
std::string key = prefix + "_" + kernelName;
156+
kernel_map_t::iterator idx = kernel_map.find(key);
157+
158+
159+
// If kernel not found for this device, set to NULL
160+
if (idx == kernel_map.end()) {
161+
*clKernel = NULL;
162+
} else {
163+
*clKernel = idx->second;
164+
}
165+
166+
delete[] kernelName;
167+
}
168+
169+
if (!*clKernel) {
129170
// kernel has not been built, so build it (from binary, preferably)
130-
cl_context clContext;
131-
cl_device_id clDevice;
132-
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_CONTEXT, sizeof(clContext), &clContext, NULL);
133-
CL_CHECK(err)
134-
err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL);
135-
CL_CHECK(err)
136171
cl_program clProgram;
137172
cl_int clBinaryStatus;
138173
if (*kernelBinary) {
@@ -151,6 +186,9 @@ void makeGemmKernel(
151186
binaryBuildOptions, NULL, NULL );
152187
CL_CHECK(err)
153188
} else {
189+
#ifdef AUTOGEMM_PRINT_DEBUG
190+
printf("makeGemmKernel: Creating program from source\n", *kernelBinarySize);
191+
#endif
154192
clProgram = clCreateProgramWithSource(
155193
clContext,
156194
1, &kernelSource,
@@ -178,6 +216,7 @@ void makeGemmKernel(
178216
printf("%s\n", buildLog);
179217
//printf("\n\nKernel String:\n\n");
180218
//printf("%s\n", kernelSource);
219+
//FIXME: The function should be exiting at this point
181220
}
182221

183222
err = clCreateKernelsInProgram(
@@ -187,32 +226,21 @@ void makeGemmKernel(
187226
CL_CHECK(err)
188227
err = clReleaseProgram(clProgram);
189228
CL_CHECK(err)
190-
229+
230+
char *kernelName = getKernelName(*clKernel);
231+
191232
#ifdef AUTOGEMM_PRINT_DEBUG
192-
// get kernel name
193-
size_t kernelNameLength;
194-
err = clGetKernelInfo(
195-
*clKernel,
196-
CL_KERNEL_FUNCTION_NAME,
197-
sizeof(kernelNameLength),
198-
NULL,
199-
&kernelNameLength );
200-
CL_CHECK(err)
201-
char *kernelName = new char[kernelNameLength];
202-
err = clGetKernelInfo(
203-
*clKernel,
204-
CL_KERNEL_FUNCTION_NAME,
205-
kernelNameLength*sizeof(char),
206-
kernelName,
207-
NULL );
208-
CL_CHECK(err)
209233
printf("makeGemmKernel: \"%s\" now built; returning.\n", kernelName);
210-
delete[] kernelName;
211234
#endif
235+
236+
std::string key = prefix + "_" + kernelName;
237+
kernel_map[key] = *clKernel;
238+
delete[] kernelName;
212239
}
240+
241+
return;
213242
}
214243

215-
216244
/******************************************************************************
217245
* Enqueue Gemm Kernel
218246
*****************************************************************************/
@@ -266,7 +294,7 @@ template<> clblasTranspose correctTranspose<DoubleComplex>( clblasTranspose tran
266294
* templated Gemm
267295
*****************************************************************************/
268296
template<typename Precision>
269-
clblasStatus
297+
clblasStatus
270298
clblasGemm(
271299
clblasOrder order,
272300
clblasTranspose transA,
@@ -308,7 +336,7 @@ clblasGemm(
308336
M, N, offA, offB, lda, ldb, A, B );
309337

310338

311-
339+
312340
/******************************************************************************
313341
* Handle Special Cases
314342
*
@@ -318,7 +346,7 @@ clblasGemm(
318346
* and are mod32 but not mod96 or mod64
319347
*
320348
*****************************************************************************/
321-
349+
322350
bool specialCaseHandled = false;
323351

324352
clblasStatus SpecialCaseStatus = GemmSpecialCases<Precision>(order,
@@ -339,8 +367,8 @@ clblasGemm(
339367

340368
if (specialCaseHandled)
341369
return SpecialCaseStatus;
342-
343-
370+
371+
344372
/******************************************************************************
345373
* Optimal num elements per thread
346374
*****************************************************************************/
@@ -512,7 +540,7 @@ clblasGemm(
512540
gemmKernelArgs[11] = &offA; gemmKernelArgSizes[11] = sizeof(cl_uint);
513541
gemmKernelArgs[12] = &offB; gemmKernelArgSizes[12] = sizeof(cl_uint);
514542
gemmKernelArgs[13] = &offC; gemmKernelArgSizes[13] = sizeof(cl_uint);
515-
543+
516544

517545
/******************************************************************************
518546
* Enqueue Tile kernel
@@ -577,8 +605,8 @@ clblasGemm(
577605
/******************************************************************************
578606
* SGEMM API call
579607
*****************************************************************************/
580-
extern "C"
581-
clblasStatus
608+
extern "C"
609+
clblasStatus
582610
clblasSgemm(
583611
clblasOrder order,
584612
clblasTranspose transA,
@@ -615,7 +643,7 @@ clblasSgemm(
615643
/******************************************************************************
616644
* DGEMM API call
617645
*****************************************************************************/
618-
extern "C"
646+
extern "C"
619647
clblasStatus
620648
clblasDgemm( clblasOrder order,
621649
clblasTranspose transA,
@@ -652,7 +680,7 @@ clblasDgemm( clblasOrder order,
652680
/******************************************************************************
653681
* CGEMM API call
654682
*****************************************************************************/
655-
extern "C"
683+
extern "C"
656684
clblasStatus
657685
clblasCgemm(
658686
clblasOrder order,
@@ -690,7 +718,7 @@ clblasCgemm(
690718
/******************************************************************************
691719
* ZGEMM API
692720
*****************************************************************************/
693-
extern "C"
721+
extern "C"
694722
clblasStatus
695723
clblasZgemm(
696724
clblasOrder order,

0 commit comments

Comments
 (0)