Skip to content

Commit 7471dcb

Browse files
author
Kent Knox
committed
Merge pull request #36 from TimmyLiu/develop
enable rect read/write for gemm in client (good performance when lda, ldb and ldc is bigger than m, n and k); Also bug fixes of hemm and symm hanging and ssyr2k crashing when using the tuning kdb files
2 parents 900b211 + 342cc8f commit 7471dcb

File tree

5 files changed

+133
-16
lines changed

5 files changed

+133
-16
lines changed

src/client/clfunc_common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ class clblasFunc
313313
virtual void reset_gpu_write_buffer() = 0;
314314
virtual void read_gpu_buffer() = 0;
315315
virtual void roundtrip_func() = 0;
316+
virtual void roundtrip_func_rect() {}
316317
virtual void allochostptr_roundtrip_func() {}
317318
virtual void usehostptr_roundtrip_func() {}
318319
virtual void copyhostptr_roundtrip_func() {}

src/client/clfunc_xgemm.hpp

Lines changed: 104 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,103 @@ class xGemm : public clblasFunc
454454
clWaitForEvents(1, &event_);
455455
timer.Stop(timer_id);
456456
}
457+
void roundtrip_func_rect()
458+
{
459+
timer.Start(timer_id);
460+
cl_int err;
461+
//rect
462+
size_t a_buffer_origin[3] = {0,0,0};
463+
size_t a_host_origin[3] = {0,0,0};
464+
size_t a_region[3] = {buffer_.m_*sizeof(T),buffer_.k_,1};
465+
size_t a_buffer_row_pitch=0*sizeof(T);//lda
466+
size_t a_buffer_slice_pitch=0;
467+
size_t a_host_row_pitch=buffer_.lda_*sizeof(T);
468+
size_t a_host_slice_pitch=0;
469+
470+
size_t b_buffer_origin[3] = {0,0,0};
471+
size_t b_host_origin[3] = {0,0,0};
472+
size_t b_region[3] = {buffer_.k_*sizeof(T),buffer_.n_,1};
473+
size_t b_buffer_row_pitch=0*sizeof(T);//ldb
474+
size_t b_buffer_slice_pitch=0;
475+
size_t b_host_row_pitch=buffer_.ldb_*sizeof(T);
476+
size_t b_host_slice_pitch=0;
477+
478+
size_t c_buffer_origin[3] = {0,0,0};
479+
size_t c_host_origin[3] = {0,0,0};
480+
size_t c_region[3] = {buffer_.m_*sizeof(T),buffer_.n_,1};
481+
size_t c_buffer_row_pitch=0*sizeof(T);//ldc
482+
size_t c_buffer_slice_pitch=0;
483+
size_t c_host_row_pitch=buffer_.ldc_*sizeof(T);
484+
size_t c_host_slice_pitch=0;
485+
486+
buffer_.buf_a_ = clCreateBuffer(ctx_, CL_MEM_READ_ONLY,
487+
(buffer_.k_*buffer_.m_ +
488+
buffer_.offA_) * sizeof(T),
489+
NULL, &err);
490+
491+
buffer_.buf_b_ = clCreateBuffer(ctx_, CL_MEM_READ_ONLY,
492+
(buffer_.k_ * buffer_.n_ +
493+
buffer_.offB_) * sizeof(T),
494+
NULL, &err);
495+
496+
buffer_.buf_c_ = clCreateBuffer(ctx_, CL_MEM_READ_WRITE,
497+
(buffer_.m_ * buffer_.n_ +
498+
buffer_.offC_) * sizeof(T),
499+
NULL, &err);
500+
/*
501+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_a_, CL_TRUE,
502+
buffer_.offA_ * sizeof(T),
503+
buffer_.lda_ * buffer_.a_num_vectors_ *
504+
sizeof(T),
505+
buffer_.a_, 0, NULL, NULL);
506+
507+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_b_, CL_TRUE,
508+
buffer_.offB_ * sizeof(T),
509+
buffer_.ldb_ * buffer_.b_num_vectors_ *
510+
sizeof(T),
511+
buffer_.b_, 0, NULL, NULL);
512+
513+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_c_, CL_TRUE,
514+
buffer_.offC_ * sizeof(T),
515+
buffer_.ldc_ * buffer_.c_num_vectors_ *
516+
sizeof(T),
517+
buffer_.c_, 0, NULL, NULL);*/
518+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_a_, CL_TRUE, a_buffer_origin, a_host_origin, a_region, a_buffer_row_pitch,
519+
a_buffer_slice_pitch, a_host_row_pitch, a_host_slice_pitch, buffer_.a_, 0, NULL, NULL);
520+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_b_, CL_TRUE, b_buffer_origin, b_host_origin, b_region, b_buffer_row_pitch,
521+
b_buffer_slice_pitch, b_host_row_pitch, b_host_slice_pitch, buffer_.b_, 0, NULL, NULL);
522+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_c_, CL_TRUE, c_buffer_origin, c_host_origin, c_region, c_buffer_row_pitch,
523+
c_buffer_slice_pitch, c_host_row_pitch, c_host_slice_pitch, buffer_.c_, 0, NULL, NULL);
524+
525+
if(buffer_.trans_a_==clblasNoTrans)
526+
{
527+
buffer_.lda_=buffer_.m_;
528+
}
529+
else
530+
{
531+
buffer_.lda_=buffer_.k_;
532+
}
533+
if(buffer_.trans_b_==clblasNoTrans)
534+
{
535+
buffer_.ldb_=buffer_.k_;
536+
}
537+
else
538+
{
539+
buffer_.ldb_=buffer_.n_;
540+
}
541+
buffer_.ldc_=buffer_.m_;
542+
xGemm_Function(false);
543+
/*
544+
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
545+
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
546+
sizeof(T),
547+
buffer_.c_, 0, NULL, &event_);
548+
*/
549+
err = ::clEnqueueReadBufferRect(queue_, buffer_.buf_c_, CL_TRUE, c_buffer_origin, c_host_origin, c_region, c_buffer_row_pitch,
550+
c_buffer_slice_pitch, c_host_row_pitch, c_host_slice_pitch, buffer_.c_, 0, NULL, &event_);
551+
clWaitForEvents(1, &event_);
552+
timer.Stop(timer_id);
553+
}
457554
void allochostptr_roundtrip_func()
458555
{
459556
timer.Start(timer_id);
@@ -528,12 +625,7 @@ class xGemm : public clblasFunc
528625
(buffer_.ldc_ * buffer_.c_num_vectors_ +
529626
buffer_.offC_) * sizeof(T),
530627
buffer_.c_, &err);
531-
xGemm_Function(false);
532-
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
533-
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
534-
sizeof(T),
535-
buffer_.c_, 0, NULL, &event_);
536-
clWaitForEvents(1, &event_);
628+
xGemm_Function(true);
537629
timer.Stop(timer_id);
538630
}
539631
void copyhostptr_roundtrip_func()
@@ -554,7 +646,12 @@ class xGemm : public clblasFunc
554646
(buffer_.ldc_ * buffer_.c_num_vectors_ +
555647
buffer_.offC_) * sizeof(T),
556648
buffer_.c_, &err);
557-
xGemm_Function(true);
649+
xGemm_Function(false);
650+
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
651+
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
652+
sizeof(T),
653+
buffer_.c_, 0, NULL, &event_);
654+
clWaitForEvents(1, &event_);
558655
timer.Stop(timer_id);
559656
}
560657
void usepersismem_roundtrip_func()

src/client/client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ int main(int argc, char *argv[])
102102
( "diag", po::value<int>( &diag_option )->default_value(0), "0 = unit diagonal, 1 = non unit diagonal. only used with [list of function families]" ) // xtrsm xtrmm
103103
( "profile,p", po::value<cl_uint>( &profileCount )->default_value(20), "Time and report the kernel speed (default: profiling off)" )
104104
( "roundtrip", po::value<std::string>( &roundtrip )->default_value("noroundtrip"),"including the time of OpenCL memory allocation and transportation; options:roundtrip, noroundtrip(default)")
105-
( "memalloc", po::value<std::string>( &memalloc )->default_value("default"),"setting the memory allocation flags for OpenCL; would not take effect if roundtrip time is not measured; options:default(default),alloc_host_ptr,use_host_ptr,copy_host_ptr,use_persistent_mem_amd")
105+
( "memalloc", po::value<std::string>( &memalloc )->default_value("default"),"setting the memory allocation flags for OpenCL; would not take effect if roundtrip time is not measured; options:default(default),alloc_host_ptr,use_host_ptr,copy_host_ptr,use_persistent_mem_amd,rect_mem")
106106
;
107107

108108
po::variables_map vm;
@@ -534,6 +534,10 @@ int main(int argc, char *argv[])
534534
{
535535
my_function->usepersismem_roundtrip_func();
536536
}
537+
else if (memalloc=="rect_mem")
538+
{
539+
my_function->roundtrip_func_rect();
540+
}
537541
//my_function->reset_gpu_write_buffer();
538542
my_function->releaseGPUBuffer_deleteCPUBuffer();
539543
}

src/library/blas/generic/solution_seq_make.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,9 +1435,12 @@ getStepGranulation(SolutionStep *step)
14351435
}
14361436
}
14371437

1438-
status = getGranularityInfo(&step->device, mempat->name,
1439-
step->args.dtype, step->extraFlags,
1440-
(int)MNK, dims, &step->pgran, &time);
1438+
if( step->funcID != CLBLAS_GEMM2 )
1439+
{
1440+
status = getGranularityInfo(&step->device, mempat->name,
1441+
step->args.dtype, step->extraFlags,
1442+
(int)MNK, dims, &step->pgran, &time);
1443+
}
14411444
/*
14421445
* Disable blocking for implementations dealing with cache reads
14431446
* from the global memory

src/library/blas/gens/syrxk.c

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <string.h>
2323
#include <stdio.h>
24+
#include <stdlib.h>
2425
#include <assert.h>
2526

2627
#include <clBLAS.h>
@@ -1219,10 +1220,11 @@ genUpdateGenericDiagTile(
12191220
// type of the vectorized coordinates
12201221
Kstring vctype;
12211222
Kstring constOffs, constShifts, constMasks;
1222-
unsigned int i, j, nops;
1223+
unsigned int i, j, nops,size;
12231224
unsigned int maxFetches = 0;
12241225
const char *yname, *xname;
12251226
const char *ldcName;
1227+
char hexadec[2];
12261228

12271229
batch = createStmtBatch();
12281230
if (batch == NULL) {
@@ -1253,6 +1255,14 @@ genUpdateGenericDiagTile(
12531255
tifl = (isUpper) ? TILE_ITER_BACKWARD_ROWS :
12541256
TILE_ITER_BACKWARD_COLS;
12551257
iterInit(&iter, &tileTempC, 1, tifl);
1258+
nops = 0;
1259+
while (!iterIsEnd(&iter)) {
1260+
nops++;
1261+
size = nops / nrCols;
1262+
iterIterate(&iter);
1263+
}
1264+
1265+
iterInit(&iter, &tileTempC, 1, tifl);
12561266

12571267
initTmpResTile(&tileTempC, gset, true);
12581268

@@ -1316,7 +1326,7 @@ genUpdateGenericDiagTile(
13161326
maxFetches = umin(maxFetches, i);
13171327

13181328
// declare vectorized coordinates
1319-
declareDiagUpresIndexedVars(ctx, vctype.buf, "cc", tempRows);
1329+
declareDiagUpresIndexedVars(ctx, vctype.buf, "cc", size);
13201330

13211331
/*
13221332
* real y coordinate, offset mask and
@@ -1326,8 +1336,8 @@ genUpdateGenericDiagTile(
13261336
"unsigned int mask;\n"
13271337
"int hit;\n");
13281338
if (withBeta) {
1329-
declareDiagUpresIndexedVars(ctx, typeName, "alphaNew", tempRows);
1330-
declareDiagUpresIndexedVars(ctx, typeName, "betaNew", tempRows);
1339+
declareDiagUpresIndexedVars(ctx, typeName, "alphaNew", size);
1340+
declareDiagUpresIndexedVars(ctx, typeName, "betaNew", size);
13311341
}
13321342

13331343
// declare tile
@@ -1443,7 +1453,9 @@ genUpdateGenericDiagTile(
14431453
ksprintf(&kstr, "cc%u", i);
14441454
}
14451455
else {
1446-
ksprintf(&kstr, "cc%u.s%u", i, iter.col);
1456+
snprintf(hexadec, sizeof(char)*2, "%x", iter.col);
1457+
//itoa(iter.col, hexadec, 16);
1458+
ksprintf(&kstr, "cc%u.s%s", i, hexadec);
14471459
}
14481460

14491461
// prepare multipliers and fetch

0 commit comments

Comments
 (0)