Skip to content

Commit 15c03be

Browse files
author
Timmy
committed
enable rect read/write for gemm
1 parent da0a638 commit 15c03be

File tree

3 files changed

+96
-8
lines changed

3 files changed

+96
-8
lines changed

src/client/clfunc_common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ class clblasFunc
313313
virtual void reset_gpu_write_buffer() = 0;
314314
virtual void read_gpu_buffer() = 0;
315315
virtual void roundtrip_func() = 0;
316+
virtual void roundtrip_func_rect() {}
316317
virtual void allochostptr_roundtrip_func() {}
317318
virtual void usehostptr_roundtrip_func() {}
318319
virtual void copyhostptr_roundtrip_func() {}

src/client/clfunc_xgemm.hpp

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,89 @@ class xGemm : public clblasFunc
454454
clWaitForEvents(1, &event_);
455455
timer.Stop(timer_id);
456456
}
457+
void roundtrip_func_rect()
458+
{
459+
timer.Start(timer_id);
460+
cl_int err;
461+
//rect
462+
size_t a_buffer_origin[3] = {0,0,0};
463+
size_t a_host_origin[3] = {0,0,0};
464+
size_t a_region[3] = {buffer_.m_*sizeof(T),buffer_.k_,1};
465+
size_t a_buffer_row_pitch=0*sizeof(T);//lda
466+
size_t a_buffer_slice_pitch=0;
467+
size_t a_host_row_pitch=buffer_.lda_*sizeof(T);
468+
size_t a_host_slice_pitch=0;
469+
470+
size_t b_buffer_origin[3] = {0,0,0};
471+
size_t b_host_origin[3] = {0,0,0};
472+
size_t b_region[3] = {buffer_.k_*sizeof(T),buffer_.n_,1};
473+
size_t b_buffer_row_pitch=0*sizeof(T);//ldb
474+
size_t b_buffer_slice_pitch=0;
475+
size_t b_host_row_pitch=buffer_.ldb_*sizeof(T);
476+
size_t b_host_slice_pitch=0;
477+
478+
size_t c_buffer_origin[3] = {0,0,0};
479+
size_t c_host_origin[3] = {0,0,0};
480+
size_t c_region[3] = {buffer_.m_*sizeof(T),buffer_.n_,1};
481+
size_t c_buffer_row_pitch=0*sizeof(T);//ldc
482+
size_t c_buffer_slice_pitch=0;
483+
size_t c_host_row_pitch=buffer_.ldc_*sizeof(T);
484+
size_t c_host_slice_pitch=0;
485+
486+
buffer_.buf_a_ = clCreateBuffer(ctx_, CL_MEM_READ_ONLY,
487+
(buffer_.k_*buffer_.m_ +
488+
buffer_.offA_) * sizeof(T),
489+
NULL, &err);
490+
491+
buffer_.buf_b_ = clCreateBuffer(ctx_, CL_MEM_READ_ONLY,
492+
(buffer_.k_ * buffer_.n_ +
493+
buffer_.offB_) * sizeof(T),
494+
NULL, &err);
495+
496+
buffer_.buf_c_ = clCreateBuffer(ctx_, CL_MEM_READ_WRITE,
497+
(buffer_.m_ * buffer_.n_ +
498+
buffer_.offC_) * sizeof(T),
499+
NULL, &err);
500+
/*
501+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_a_, CL_TRUE,
502+
buffer_.offA_ * sizeof(T),
503+
buffer_.lda_ * buffer_.a_num_vectors_ *
504+
sizeof(T),
505+
buffer_.a_, 0, NULL, NULL);
506+
507+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_b_, CL_TRUE,
508+
buffer_.offB_ * sizeof(T),
509+
buffer_.ldb_ * buffer_.b_num_vectors_ *
510+
sizeof(T),
511+
buffer_.b_, 0, NULL, NULL);
512+
513+
err = clEnqueueWriteBuffer(queue_, buffer_.buf_c_, CL_TRUE,
514+
buffer_.offC_ * sizeof(T),
515+
buffer_.ldc_ * buffer_.c_num_vectors_ *
516+
sizeof(T),
517+
buffer_.c_, 0, NULL, NULL);*/
518+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_a_, CL_TRUE, a_buffer_origin, a_host_origin, a_region, a_buffer_row_pitch,
519+
a_buffer_slice_pitch, a_host_row_pitch, a_host_slice_pitch, buffer_.a_, 0, NULL, NULL);
520+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_b_, CL_TRUE, b_buffer_origin, b_host_origin, b_region, b_buffer_row_pitch,
521+
b_buffer_slice_pitch, b_host_row_pitch, b_host_slice_pitch, buffer_.b_, 0, NULL, NULL);
522+
err = clEnqueueWriteBufferRect(queue_, buffer_.buf_c_, CL_TRUE, c_buffer_origin, c_host_origin, c_region, c_buffer_row_pitch,
523+
c_buffer_slice_pitch, c_host_row_pitch, c_host_slice_pitch, buffer_.c_, 0, NULL, NULL);
524+
525+
buffer_.lda_ = 0;
526+
buffer_.ldb_ = 0;
527+
buffer_.ldc_ = 0;
528+
xGemm_Function(false);
529+
/*
530+
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
531+
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
532+
sizeof(T),
533+
buffer_.c_, 0, NULL, &event_);
534+
*/
535+
err = ::clEnqueueReadBufferRect(queue_, buffer_.buf_c_, CL_TRUE, c_buffer_origin, c_host_origin, c_region, c_buffer_row_pitch,
536+
c_buffer_slice_pitch, c_host_row_pitch, c_host_slice_pitch, buffer_.c_, 0, NULL, &event_);
537+
clWaitForEvents(1, &event_);
538+
timer.Stop(timer_id);
539+
}
457540
void allochostptr_roundtrip_func()
458541
{
459542
timer.Start(timer_id);
@@ -528,12 +611,7 @@ class xGemm : public clblasFunc
528611
(buffer_.ldc_ * buffer_.c_num_vectors_ +
529612
buffer_.offC_) * sizeof(T),
530613
buffer_.c_, &err);
531-
xGemm_Function(false);
532-
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
533-
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
534-
sizeof(T),
535-
buffer_.c_, 0, NULL, &event_);
536-
clWaitForEvents(1, &event_);
614+
xGemm_Function(true);
537615
timer.Stop(timer_id);
538616
}
539617
void copyhostptr_roundtrip_func()
@@ -554,7 +632,12 @@ class xGemm : public clblasFunc
554632
(buffer_.ldc_ * buffer_.c_num_vectors_ +
555633
buffer_.offC_) * sizeof(T),
556634
buffer_.c_, &err);
557-
xGemm_Function(true);
635+
xGemm_Function(false);
636+
err = clEnqueueReadBuffer(queue_, buffer_.buf_c_, CL_TRUE,
637+
buffer_.offC_ * sizeof(T), buffer_.ldc_ * buffer_.c_num_vectors_ *
638+
sizeof(T),
639+
buffer_.c_, 0, NULL, &event_);
640+
clWaitForEvents(1, &event_);
558641
timer.Stop(timer_id);
559642
}
560643
void usepersismem_roundtrip_func()

src/client/client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ int main(int argc, char *argv[])
102102
( "diag", po::value<int>( &diag_option )->default_value(0), "0 = unit diagonal, 1 = non unit diagonal. only used with [list of function families]" ) // xtrsm xtrmm
103103
( "profile,p", po::value<cl_uint>( &profileCount )->default_value(20), "Time and report the kernel speed (default: profiling off)" )
104104
( "roundtrip", po::value<std::string>( &roundtrip )->default_value("noroundtrip"),"including the time of OpenCL memory allocation and transportation; options:roundtrip, noroundtrip(default)")
105-
( "memalloc", po::value<std::string>( &memalloc )->default_value("default"),"setting the memory allocation flags for OpenCL; would not take effect if roundtrip time is not measured; options:default(default),alloc_host_ptr,use_host_ptr,copy_host_ptr,use_persistent_mem_amd")
105+
( "memalloc", po::value<std::string>( &memalloc )->default_value("default"),"setting the memory allocation flags for OpenCL; would not take effect if roundtrip time is not measured; options:default(default),alloc_host_ptr,use_host_ptr,copy_host_ptr,use_persistent_mem_amd,rect_mem")
106106
;
107107

108108
po::variables_map vm;
@@ -534,6 +534,10 @@ int main(int argc, char *argv[])
534534
{
535535
my_function->usepersismem_roundtrip_func();
536536
}
537+
else if (memalloc=="rect_mem")
538+
{
539+
my_function->roundtrip_func_rect();
540+
}
537541
//my_function->reset_gpu_write_buffer();
538542
my_function->releaseGPUBuffer_deleteCPUBuffer();
539543
}

0 commit comments

Comments
 (0)