@@ -5,9 +5,10 @@ template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
-    const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
-    const int     tid     = threadIdx.x;
+    const int64_t row       = blockIdx.x;
+    const int64_t channel   = blockIdx.z;
+    const int     tid       = threadIdx.x;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
     y   +=  channel               *stride_channel_y;
@@ -18,8 +19,8 @@ static __global__ void mul_mat_vec(
     extern __shared__ char data_mmv[];
     float * buf_iw = (float *) data_mmv;
 
-    if (block_size > WARP_SIZE) {
-        if (tid < WARP_SIZE) {
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
             buf_iw[tid] = 0.0f;
         }
         __syncthreads();
@@ -67,16 +68,16 @@ static __global__ void mul_mat_vec(
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum(sumf);
+    sumf = warp_reduce_sum<warp_size>(sumf);
 
-    if (block_size > WARP_SIZE) {
-        buf_iw[tid/WARP_SIZE] = sumf;
+    if (block_size > warp_size) {
+        buf_iw[tid/warp_size] = sumf;
         __syncthreads();
-        if (tid >= WARP_SIZE) {
+        if (tid >= warp_size) {
             return;
         }
         sumf = buf_iw[tid];
-        sumf = warp_reduce_sum(sumf);
+        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
     if (tid != 0) {
@@ -96,18 +97,27 @@ static void launch_mul_mat_vec_cuda(
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    int device;
+    int warp_size;
 
-    int64_t block_size_best = WARP_SIZE;
-    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
-    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
         const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
         if (niter < niter_best) {
             niter_best      = niter;
             block_size_best = block_size;
         }
     }
 
-    const int smem = WARP_SIZE*sizeof(float);
+    const int smem = warp_size*sizeof(float);
     const dim3 block_nums(nrows, 1, nchannels_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
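
The diff replaces the compile-time WARP_SIZE constant with the per-architecture warp size: the kernel obtains it as a constexpr from ggml_cuda_get_physical_warp_size() (32 on NVIDIA, 64 on AMD GCN/CDNA wavefronts) and passes it as the template width of warp_reduce_sum, while the launcher reads ggml_cuda_info().devices[device].warp_size at runtime and caps the candidate block size at 128 on AMD GPUs older than RDNA1. As a rough illustration of what a width-templated warp reduction looks like, here is a minimal sketch; warp_reduce_sum_sketch is a hypothetical stand-in rather than the ggml implementation, and plain CUDA shuffles only cover widths up to 32 (a HIP build would need the 64-lane wavefront intrinsics instead):

// Hedged sketch, not the ggml source: butterfly sum reduction whose logical
// width is a compile-time template parameter, so the same call site works for
// different physical warp sizes.
template <int width>
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    static_assert(width > 0 && width <= 32 && (width & (width - 1)) == 0,
                  "power-of-two width up to a CUDA warp");
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFFu, x, offset, width); // exchange partial sums across lanes
    }
    return x;
}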