Skip to content

Commit f3a10ab

Browse files
author
Timmy
committed
fix sgemm NT perf drop when lda=ldb=7168 or 8192 and k>lda/4
1 parent 458c9da commit f3a10ab

File tree

1 file changed

+290
-11
lines changed

1 file changed

+290
-11
lines changed

src/library/blas/functor/hawaii_sgemmBig1024Kernel.cc

Lines changed: 290 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -289,21 +289,300 @@ clBlashawaiiSgemmBig1024KernelFunctor::provide(clblasSgemmFunctor::Args & args,
289289

290290
cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch(cl_command_queue queue, cl_kernel Kernel[1], Args &args)
291291
{
292-
//((Mvalue - 1) / 128 + 1) * 16
293-
size_t GlobalX = ((args.M-1) / 128 + 1)*16 ;
294-
295-
//
296-
297-
size_t GlobalY = ((args.N - 1) / 128 + 1) * 16;
292+
if (args.lda < 7168)
293+
{
294+
//((Mvalue - 1) / 128 + 1) * 16
295+
size_t GlobalX = ((args.M - 1) / 128 + 1) * 16;
296+
size_t GlobalY = ((args.N - 1) / 128 + 1) * 16;
298297

299298

300-
std::size_t gs[2] = {GlobalX, GlobalY};
301-
cl_int error = 0;
299+
std::size_t gs[2] = { GlobalX, GlobalY };
300+
cl_int error = 0;
302301

303302

304-
//if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
305-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantBig1024->ls, args.numEventsInWaitList, args.eventWaitList,args.events);
306-
return error;
303+
//if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
304+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantBig1024->ls, args.numEventsInWaitList, args.eventWaitList, args.events);
305+
return error;
306+
}
307+
else
308+
{
309+
//for example, when M=N=K=8192
310+
//we are going to call 16 GEMMs
311+
//each GEMM has M=N=4096 and K=2048 (M/2, N/2, K/4)
312+
//note: a direct GEMM call has a 0.7 TFLOPS performance
313+
314+
// [ A11 | A12 | A13 | A14 ] [ B11 | B12 | B13 | B14 ] [ C11 | C12 ]
315+
// A = [ A21 | A22 | A23 | A24 ] B = [ B21 | B22 | B23 | B24 ] C = [ C21 | C22 ]
316+
317+
// 16 GEMMs are
318+
// #01: C11 = a*A11*B11 + b*C11
319+
// #02: C11 = a*A12*B12 + 1*C11
320+
// #03: C11 = a*A13*B13 + 1*C11
321+
// #04: C11 = a*A14*B14 + 1*C11 now we are done with C11
322+
323+
// #05: C12 = a*A11*B21 + b*C12
324+
// #06: C12 = a*A12*B22 + 1*C12
325+
// #07: C12 = a*A13*B23 + 1*C12
326+
// #08: C12 = a*A14*B24 + 1*C12 now we are done with C12
327+
328+
// #09: C21 = a*A21*B11 + b*C21
329+
// #10: C21 = a*A22*B12 + 1*C21
330+
// #11: C21 = a*A23*B13 + 1*C21
331+
// #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
332+
333+
// #13: C22 = a*A21*B21 + b*C22
334+
// #14: C22 = a*A22*B22 + 1*C22
335+
// #15: C22 = a*A23*B23 + 1*C22
336+
// #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337+
338+
unsigned int small_M = args.M / 2;
339+
unsigned int small_N = args.N / 2;
340+
unsigned int small_K = args.K / 4;
341+
342+
size_t GlobalX = ((small_M - 1) / 128 + 1) * 16;
343+
size_t GlobalY = ((small_N - 1) / 128 + 1) * 16;
344+
std::size_t gs[2] = { GlobalX, GlobalY };
345+
cl_int error = 0;
346+
347+
//GEMM #1
348+
error = clSetKernelArg(Kernel[0], 3, sizeof(cl_uint), &small_M);
349+
assert(error == CL_SUCCESS);
350+
error = clSetKernelArg(Kernel[0], 4, sizeof(cl_uint), &small_N);
351+
assert(error == CL_SUCCESS);
352+
error = clSetKernelArg(Kernel[0], 5, sizeof(cl_uint), &small_K);
353+
assert(error == CL_SUCCESS);
354+
355+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
356+
gs, m_variantBig1024->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
357+
assert(error == CL_SUCCESS);
358+
359+
// GEMM #2: C11 = a*A12*B12 + 1*C11
360+
361+
unsigned int offa_A2 = args.lda*args.K / 4;
362+
unsigned int offb_B2 = args.ldb*args.K / 4;
363+
cl_float betaone = 1;
364+
365+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
366+
assert(error == CL_SUCCESS);
367+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A2);
368+
assert(error == CL_SUCCESS);
369+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B2);
370+
assert(error == CL_SUCCESS);
371+
372+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
373+
gs, m_variantBig1024->ls, 0, NULL, NULL);
374+
assert(error == CL_SUCCESS);
375+
376+
// GEMM #3: C11 = a*A13*B13 + 1*C11
377+
378+
unsigned int offa_A3 = args.lda*args.K / 4 * 2;
379+
unsigned int offb_B3 = args.ldb*args.K / 4 * 2;
380+
381+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A3);
382+
assert(error == CL_SUCCESS);
383+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B3);
384+
assert(error == CL_SUCCESS);
385+
386+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
387+
gs, m_variantBig1024->ls, 0, NULL, NULL);
388+
assert(error == CL_SUCCESS);
389+
390+
// GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391+
392+
unsigned int offa_A4 = args.lda*args.K / 4 * 3;
393+
unsigned int offb_B4 = args.ldb*args.K / 4 * 3;
394+
395+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A4);
396+
assert(error == CL_SUCCESS);
397+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B4);
398+
assert(error == CL_SUCCESS);
399+
400+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
401+
gs, m_variantBig1024->ls, 0, NULL, NULL);
402+
assert(error == CL_SUCCESS);
403+
404+
// GEMM #5: C12 = a*A11*B21 + b*C12
405+
unsigned int offa_A5 = 0;
406+
unsigned int offb_B5 = args.N / 2;
407+
unsigned int offc_C5 = args.ldc*args.N / 2;
408+
409+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
410+
assert(error == CL_SUCCESS);
411+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A5);
412+
assert(error == CL_SUCCESS);
413+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B5);
414+
assert(error == CL_SUCCESS);
415+
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C5);
416+
assert(error == CL_SUCCESS);
417+
418+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
419+
gs, m_variantBig1024->ls, 0, NULL, NULL);
420+
assert(error == CL_SUCCESS);
421+
422+
// GEMM #6: C12 = a*A12*B22 + 1*C12
423+
unsigned int offa_A6 = args.lda*args.K / 4;
424+
unsigned int offb_B6 = args.ldb*args.K / 4 + args.N / 2;
425+
426+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
427+
assert(error == CL_SUCCESS);
428+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A6);
429+
assert(error == CL_SUCCESS);
430+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B6);
431+
assert(error == CL_SUCCESS);
432+
433+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
434+
gs, m_variantBig1024->ls, 0, NULL, NULL);
435+
assert(error == CL_SUCCESS);
436+
437+
// GEMM #7: C12 = a*A13*B23 + 1*C12
438+
unsigned int offa_A7 = args.lda*args.K / 4 * 2;
439+
unsigned int offb_B7 = args.ldb*args.K / 4 * 2 + args.N / 2;
440+
441+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A7);
442+
assert(error == CL_SUCCESS);
443+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B7);
444+
assert(error == CL_SUCCESS);
445+
446+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
447+
gs, m_variantBig1024->ls, 0, NULL, NULL);
448+
assert(error == CL_SUCCESS);
449+
450+
// GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451+
unsigned int offa_A8 = args.lda*args.K / 4 * 3;
452+
unsigned int offb_B8 = args.ldb*args.K / 4 * 3 + args.N / 2;
453+
454+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A8);
455+
assert(error == CL_SUCCESS);
456+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B8);
457+
assert(error == CL_SUCCESS);
458+
459+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
460+
gs, m_variantBig1024->ls, 0, NULL, NULL);
461+
assert(error == CL_SUCCESS);
462+
463+
// GEMM #9: C21 = a*A21*B11 + b*C21
464+
unsigned int offa_A9 = args.M / 2;
465+
unsigned int offb_B9 = 0;
466+
unsigned int offc_C9 = args.M / 2;
467+
468+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
469+
assert(error == CL_SUCCESS);
470+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A9);
471+
assert(error == CL_SUCCESS);
472+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B9);
473+
assert(error == CL_SUCCESS);
474+
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C9);
475+
assert(error == CL_SUCCESS);
476+
477+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
478+
gs, m_variantBig1024->ls, 0, NULL, NULL);
479+
assert(error == CL_SUCCESS);
480+
481+
// GEMM #10: C21 = a*A22*B12 + 1*C21
482+
483+
unsigned int offa_A10 = args.lda*args.K / 4 + args.M / 2;
484+
unsigned int offb_B10 = args.ldb*args.K / 4;
485+
486+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
487+
assert(error == CL_SUCCESS);
488+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A10);
489+
assert(error == CL_SUCCESS);
490+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B10);
491+
assert(error == CL_SUCCESS);
492+
493+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
494+
gs, m_variantBig1024->ls, 0, NULL, NULL);
495+
assert(error == CL_SUCCESS);
496+
497+
// GEMM #11: C21 = a*A23*B13 + 1*C21
498+
499+
unsigned int offa_A11 = args.lda*args.K / 4 * 2 + args.M / 2;
500+
unsigned int offb_B11 = args.ldb*args.K / 4 * 2;
501+
502+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A11);
503+
assert(error == CL_SUCCESS);
504+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B11);
505+
assert(error == CL_SUCCESS);
506+
507+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
508+
gs, m_variantBig1024->ls, 0, NULL, NULL);
509+
assert(error == CL_SUCCESS);
510+
511+
// GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512+
513+
unsigned int offa_A12 = args.lda*args.K / 4 * 3 + args.M / 2;
514+
unsigned int offb_B12 = args.ldb*args.K / 4 * 3;
515+
516+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A12);
517+
assert(error == CL_SUCCESS);
518+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B12);
519+
assert(error == CL_SUCCESS);
520+
521+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
522+
gs, m_variantBig1024->ls, 0, NULL, NULL);
523+
assert(error == CL_SUCCESS);
524+
525+
// GEMM #13: C22 = a*A21*B21 + b*C22
526+
unsigned int offa_A13 = args.M / 2;
527+
unsigned int offb_B13 = args.N / 2;
528+
unsigned int offc_C13 = args.ldc*args.N / 2 + args.M / 2;
529+
530+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
531+
assert(error == CL_SUCCESS);
532+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A13);
533+
assert(error == CL_SUCCESS);
534+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B13);
535+
assert(error == CL_SUCCESS);
536+
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C13);
537+
assert(error == CL_SUCCESS);
538+
539+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
540+
gs, m_variantBig1024->ls, 0, NULL, NULL);
541+
assert(error == CL_SUCCESS);
542+
543+
// #14: C22 = a*A22*B22 + 1*C22
544+
unsigned int offa_A14 = args.lda*args.K / 4 + args.M / 2;
545+
unsigned int offb_B14 = args.ldb*args.K / 4 + args.N / 2;
546+
547+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
548+
assert(error == CL_SUCCESS);
549+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A14);
550+
assert(error == CL_SUCCESS);
551+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B14);
552+
assert(error == CL_SUCCESS);
553+
554+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
555+
gs, m_variantBig1024->ls, 0, NULL, NULL);
556+
assert(error == CL_SUCCESS);
557+
558+
// #15: C22 = a*A23*B23 + 1*C22
559+
unsigned int offa_A15 = args.lda*args.K / 4 * 2 + args.M / 2;
560+
unsigned int offb_B15 = args.ldb*args.K / 4 * 2 + args.N / 2;
561+
562+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A15);
563+
assert(error == CL_SUCCESS);
564+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B15);
565+
assert(error == CL_SUCCESS);
566+
567+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
568+
gs, m_variantBig1024->ls, 0, NULL, NULL);
569+
assert(error == CL_SUCCESS);
570+
571+
// #16: C22 = a*A24*B24 + 1*C22
572+
unsigned int offa_A16 = args.lda*args.K / 4 * 3 + args.M / 2;
573+
unsigned int offb_B16 = args.ldb*args.K / 4 * 3 + args.N / 2;
574+
575+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A16);
576+
assert(error == CL_SUCCESS);
577+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B16);
578+
assert(error == CL_SUCCESS);
579+
580+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
581+
gs, m_variantBig1024->ls, 0, NULL, args.events);
582+
assert(error == CL_SUCCESS);
583+
584+
return error;
585+
}
307586

308587

309588
return clblasNotImplemented;

0 commit comments

Comments
 (0)