Skip to content

Commit 5a74faf

Browse files
author
Timmy
committed
code clean up
1 parent f3a10ab commit 5a74faf

File tree

2 files changed

+73
-232
lines changed

2 files changed

+73
-232
lines changed

src/library/blas/functor/hawaii.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
123123
{
124124
if (args.lda != 6144)// 6144 is handled by a special case split
125125
{
126-
if (args.M % 128 == 0 && args.N % 128 == 0 && args.K % 64 == 0)
126+
// we are going to call 16 GEMMs with M=M/2, N=N/2, K=K/4
127+
// each GEMM requires M%128 == 0, N%128 == 0, K%16 == 0
128+
if (args.M % 256 == 0 && args.N % 256 == 0 && args.K % 64 == 0)
127129
{
128130
functor = clBlashawaiiSgemmBig1024KernelFunctor::provide(args, "Hawaii");
129131
if (functor)

src/library/blas/functor/hawaii_sgemmBig1024Kernel.cc

Lines changed: 70 additions & 231 deletions
Original file line numberDiff line numberDiff line change
@@ -335,251 +335,90 @@ cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch(cl_command_queue que
335335
// #15: C22 = a*A23*B23 + 1*C22
336336
// #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337337

338-
unsigned int small_M = args.M / 2;
339-
unsigned int small_N = args.N / 2;
340-
unsigned int small_K = args.K / 4;
338+
unsigned int K_split_factor = 4;
339+
unsigned int M_split_factor = 2;
340+
unsigned int N_split_factor = 2;
341+
342+
unsigned int small_M = args.M / M_split_factor;
343+
unsigned int small_N = args.N / N_split_factor;
344+
unsigned int small_K = args.K / K_split_factor;
341345

342346
size_t GlobalX = ((small_M - 1) / 128 + 1) * 16;
343347
size_t GlobalY = ((small_N - 1) / 128 + 1) * 16;
344348
std::size_t gs[2] = { GlobalX, GlobalY };
345349
cl_int error = 0;
346350

347-
//GEMM #1
351+
cl_float betaone = 1;
352+
348353
error = clSetKernelArg(Kernel[0], 3, sizeof(cl_uint), &small_M);
349354
assert(error == CL_SUCCESS);
350355
error = clSetKernelArg(Kernel[0], 4, sizeof(cl_uint), &small_N);
351356
assert(error == CL_SUCCESS);
352357
error = clSetKernelArg(Kernel[0], 5, sizeof(cl_uint), &small_K);
353358
assert(error == CL_SUCCESS);
354359

355-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
356-
gs, m_variantBig1024->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
357-
assert(error == CL_SUCCESS);
358-
359-
// GEMM #2: C11 = a*A12*B12 + 1*C11
360-
361-
unsigned int offa_A2 = args.lda*args.K / 4;
362-
unsigned int offb_B2 = args.ldb*args.K / 4;
363-
cl_float betaone = 1;
364-
365-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
366-
assert(error == CL_SUCCESS);
367-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A2);
368-
assert(error == CL_SUCCESS);
369-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B2);
370-
assert(error == CL_SUCCESS);
371-
372-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
373-
gs, m_variantBig1024->ls, 0, NULL, NULL);
374-
assert(error == CL_SUCCESS);
375-
376-
// GEMM #3: C11 = a*A13*B13 + 1*C11
377-
378-
unsigned int offa_A3 = args.lda*args.K / 4 * 2;
379-
unsigned int offb_B3 = args.ldb*args.K / 4 * 2;
380-
381-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A3);
382-
assert(error == CL_SUCCESS);
383-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B3);
384-
assert(error == CL_SUCCESS);
385-
386-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
387-
gs, m_variantBig1024->ls, 0, NULL, NULL);
388-
assert(error == CL_SUCCESS);
389-
390-
// GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391-
392-
unsigned int offa_A4 = args.lda*args.K / 4 * 3;
393-
unsigned int offb_B4 = args.ldb*args.K / 4 * 3;
394-
395-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A4);
396-
assert(error == CL_SUCCESS);
397-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B4);
398-
assert(error == CL_SUCCESS);
399-
400-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
401-
gs, m_variantBig1024->ls, 0, NULL, NULL);
402-
assert(error == CL_SUCCESS);
403-
404-
// GEMM #5: C12 = a*A11*B21 + b*C12
405-
unsigned int offa_A5 = 0;
406-
unsigned int offb_B5 = args.N / 2;
407-
unsigned int offc_C5 = args.ldc*args.N / 2;
408-
409-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
410-
assert(error == CL_SUCCESS);
411-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A5);
412-
assert(error == CL_SUCCESS);
413-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B5);
414-
assert(error == CL_SUCCESS);
415-
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C5);
416-
assert(error == CL_SUCCESS);
417-
418-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
419-
gs, m_variantBig1024->ls, 0, NULL, NULL);
420-
assert(error == CL_SUCCESS);
421-
422-
// GEMM #6: C12 = a*A12*B22 + 1*C12
423-
unsigned int offa_A6 = args.lda*args.K / 4;
424-
unsigned int offb_B6 = args.ldb*args.K / 4 + args.N / 2;
425-
426-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
427-
assert(error == CL_SUCCESS);
428-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A6);
429-
assert(error == CL_SUCCESS);
430-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B6);
431-
assert(error == CL_SUCCESS);
432-
433-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
434-
gs, m_variantBig1024->ls, 0, NULL, NULL);
435-
assert(error == CL_SUCCESS);
436-
437-
// GEMM #7: C12 = a*A13*B23 + 1*C12
438-
unsigned int offa_A7 = args.lda*args.K / 4 * 2;
439-
unsigned int offb_B7 = args.ldb*args.K / 4 * 2 + args.N / 2;
440-
441-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A7);
442-
assert(error == CL_SUCCESS);
443-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B7);
444-
assert(error == CL_SUCCESS);
445-
446-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
447-
gs, m_variantBig1024->ls, 0, NULL, NULL);
448-
assert(error == CL_SUCCESS);
449-
450-
// GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451-
unsigned int offa_A8 = args.lda*args.K / 4 * 3;
452-
unsigned int offb_B8 = args.ldb*args.K / 4 * 3 + args.N / 2;
453-
454-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A8);
455-
assert(error == CL_SUCCESS);
456-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B8);
457-
assert(error == CL_SUCCESS);
458-
459-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
460-
gs, m_variantBig1024->ls, 0, NULL, NULL);
461-
assert(error == CL_SUCCESS);
462-
463-
// GEMM #9: C21 = a*A21*B11 + b*C21
464-
unsigned int offa_A9 = args.M / 2;
465-
unsigned int offb_B9 = 0;
466-
unsigned int offc_C9 = args.M / 2;
467-
468-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
469-
assert(error == CL_SUCCESS);
470-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A9);
471-
assert(error == CL_SUCCESS);
472-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B9);
473-
assert(error == CL_SUCCESS);
474-
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C9);
475-
assert(error == CL_SUCCESS);
476-
477-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
478-
gs, m_variantBig1024->ls, 0, NULL, NULL);
479-
assert(error == CL_SUCCESS);
480-
481-
// GEMM #10: C21 = a*A22*B12 + 1*C21
482-
483-
unsigned int offa_A10 = args.lda*args.K / 4 + args.M / 2;
484-
unsigned int offb_B10 = args.ldb*args.K / 4;
485-
486-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
487-
assert(error == CL_SUCCESS);
488-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A10);
489-
assert(error == CL_SUCCESS);
490-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B10);
491-
assert(error == CL_SUCCESS);
492-
493-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
494-
gs, m_variantBig1024->ls, 0, NULL, NULL);
495-
assert(error == CL_SUCCESS);
496-
497-
// GEMM #11: C21 = a*A23*B13 + 1*C21
498-
499-
unsigned int offa_A11 = args.lda*args.K / 4 * 2 + args.M / 2;
500-
unsigned int offb_B11 = args.ldb*args.K / 4 * 2;
501-
502-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A11);
503-
assert(error == CL_SUCCESS);
504-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B11);
505-
assert(error == CL_SUCCESS);
506-
507-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
508-
gs, m_variantBig1024->ls, 0, NULL, NULL);
509-
assert(error == CL_SUCCESS);
510-
511-
// GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512-
513-
unsigned int offa_A12 = args.lda*args.K / 4 * 3 + args.M / 2;
514-
unsigned int offb_B12 = args.ldb*args.K / 4 * 3;
515-
516-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A12);
517-
assert(error == CL_SUCCESS);
518-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B12);
519-
assert(error == CL_SUCCESS);
520-
521-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
522-
gs, m_variantBig1024->ls, 0, NULL, NULL);
523-
assert(error == CL_SUCCESS);
524-
525-
// GEMM #13: C22 = a*A21*B21 + b*C22
526-
unsigned int offa_A13 = args.M / 2;
527-
unsigned int offb_B13 = args.N / 2;
528-
unsigned int offc_C13 = args.ldc*args.N / 2 + args.M / 2;
529-
530-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
531-
assert(error == CL_SUCCESS);
532-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A13);
533-
assert(error == CL_SUCCESS);
534-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B13);
535-
assert(error == CL_SUCCESS);
536-
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C13);
537-
assert(error == CL_SUCCESS);
538-
539-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
540-
gs, m_variantBig1024->ls, 0, NULL, NULL);
541-
assert(error == CL_SUCCESS);
542-
543-
// #14: C22 = a*A22*B22 + 1*C22
544-
unsigned int offa_A14 = args.lda*args.K / 4 + args.M / 2;
545-
unsigned int offb_B14 = args.ldb*args.K / 4 + args.N / 2;
546-
547-
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
548-
assert(error == CL_SUCCESS);
549-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A14);
550-
assert(error == CL_SUCCESS);
551-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B14);
552-
assert(error == CL_SUCCESS);
553-
554-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
555-
gs, m_variantBig1024->ls, 0, NULL, NULL);
556-
assert(error == CL_SUCCESS);
557-
558-
// #15: C22 = a*A23*B23 + 1*C22
559-
unsigned int offa_A15 = args.lda*args.K / 4 * 2 + args.M / 2;
560-
unsigned int offb_B15 = args.ldb*args.K / 4 * 2 + args.N / 2;
561-
562-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A15);
563-
assert(error == CL_SUCCESS);
564-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B15);
565-
assert(error == CL_SUCCESS);
566-
567-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
568-
gs, m_variantBig1024->ls, 0, NULL, NULL);
569-
assert(error == CL_SUCCESS);
570-
571-
// #16: C22 = a*A24*B24 + 1*C22
572-
unsigned int offa_A16 = args.lda*args.K / 4 * 3 + args.M / 2;
573-
unsigned int offb_B16 = args.ldb*args.K / 4 * 3 + args.N / 2;
574-
575-
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A16);
576-
assert(error == CL_SUCCESS);
577-
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B16);
578-
assert(error == CL_SUCCESS);
360+
for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++)
361+
{
362+
//2 groups of GEMMs split by M from example
363+
for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++)
364+
{
365+
//2 groups of GEMMs split by N from example
366+
unsigned int offc_C = args.ldc*args.N / N_split_factor * N_split_index + args.M / M_split_factor * M_split_index + args.offC;
367+
error = clSetKernelArg(Kernel[0], 13, sizeof(cl_uint), &offc_C);
368+
assert(error == CL_SUCCESS);
369+
370+
for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++)
371+
{
372+
//4 GEMMs split by K from example
373+
unsigned int offa_A = (args.M / M_split_factor * M_split_index) + (args.lda * args.K / K_split_factor * K_split_index) + args.offA;
374+
unsigned int offb_B = (args.N / N_split_factor * N_split_index) + (args.ldb * args.K / K_split_factor * K_split_index) + args.offB;
375+
error = clSetKernelArg(Kernel[0], 11, sizeof(cl_uint), &offa_A);
376+
assert(error == CL_SUCCESS);
377+
error = clSetKernelArg(Kernel[0], 12, sizeof(cl_uint), &offb_B);
378+
assert(error == CL_SUCCESS);
379+
380+
if (K_split_index == 0)
381+
{
382+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &(args.beta));
383+
assert(error == CL_SUCCESS);
384+
385+
if (M_split_index == 0 && N_split_index == 0)
386+
{
387+
//very first GEMM call
388+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
389+
gs, m_variantBig1024->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
390+
assert(error == CL_SUCCESS);
391+
}
392+
else
393+
{
394+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
395+
gs, m_variantBig1024->ls, 0, NULL, NULL);
396+
assert(error == CL_SUCCESS);
397+
}
398+
}
399+
else
400+
{
401+
error = clSetKernelArg(Kernel[0], 7, sizeof(cl_float), &betaone);
402+
assert(error == CL_SUCCESS);
403+
404+
if ((M_split_index == (M_split_factor - 1) ) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1)))
405+
{
406+
//very last GEMM call
407+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
408+
gs, m_variantBig1024->ls, 0, NULL, args.events);
409+
assert(error == CL_SUCCESS);
410+
}
411+
else
412+
{
413+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
414+
gs, m_variantBig1024->ls, 0, NULL, NULL);
415+
assert(error == CL_SUCCESS);
416+
}
417+
}
418+
}
419+
}
420+
}
579421

580-
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL,
581-
gs, m_variantBig1024->ls, 0, NULL, args.events);
582-
assert(error == CL_SUCCESS);
583422

584423
return error;
585424
}

0 commit comments

Comments
 (0)