@@ -18,8 +18,8 @@ static __global__ void mul_mat_vec(
1818 extern __shared__ char data_mmv[];
1919 float * buf_iw = (float *) data_mmv;
2020
21- if (block_size > WARP_SIZE ) {
22- if (tid < WARP_SIZE ) {
21+ if (block_size > GGML_TRUE_WARP_SIZE ) {
22+ if (tid < GGML_TRUE_WARP_SIZE ) {
2323 buf_iw[tid] = 0 .0f ;
2424 }
2525 __syncthreads ();
@@ -67,16 +67,16 @@ static __global__ void mul_mat_vec(
6767 static_assert (std::is_same<T, void >::value, " unsupported type" );
6868 }
6969
70- sumf = warp_reduce_sum (sumf);
70+ sumf = warp_reduce_sum<GGML_TRUE_WARP_SIZE> (sumf);
7171
72- if (block_size > WARP_SIZE ) {
73- buf_iw[tid/WARP_SIZE ] = sumf;
72+ if (block_size > GGML_TRUE_WARP_SIZE ) {
73+ buf_iw[tid/GGML_TRUE_WARP_SIZE ] = sumf;
7474 __syncthreads ();
75- if (tid >= WARP_SIZE ) {
75+ if (tid >= GGML_TRUE_WARP_SIZE ) {
7676 return ;
7777 }
7878 sumf = buf_iw[tid];
79- sumf = warp_reduce_sum (sumf);
79+ sumf = warp_reduce_sum<GGML_TRUE_WARP_SIZE> (sumf);
8080 }
8181
8282 if (tid != 0 ) {
@@ -96,18 +96,27 @@ static void launch_mul_mat_vec_cuda(
9696 GGML_ASSERT (stride_row % 2 == 0 );
9797 GGML_ASSERT (nchannels_y % nchannels_x == 0 );
9898 const int64_t channel_ratio = nchannels_y / nchannels_x;
99+ int device;
100+ int warp_size;
99101
100- int64_t block_size_best = WARP_SIZE;
101- int64_t niter_best = (ncols + 2 *WARP_SIZE - 1 ) / (2 *WARP_SIZE);
102- for (int64_t block_size = 2 *WARP_SIZE; block_size <= 256 ; block_size += WARP_SIZE) {
102+ CUDA_CHECK (cudaGetDevice (&device));
103+ warp_size = ggml_cuda_info ().devices [device].warp_size ;
104+
105+ int64_t block_size_best = warp_size;
106+ int64_t niter_best = (ncols + 2 *warp_size - 1 ) / (2 *warp_size);
107+ int64_t max_block_size = 256 ;
108+ if (ggml_cuda_info ().devices [device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info ().devices [device].cc < GGML_CUDA_CC_RDNA1) {
109+ max_block_size = 128 ;
110+ }
111+ for (int64_t block_size = 2 *warp_size; block_size <= max_block_size; block_size += warp_size) {
103112 const int64_t niter = (ncols + 2 *block_size - 1 ) / (2 *block_size);
104113 if (niter < niter_best) {
105114 niter_best = niter;
106115 block_size_best = block_size;
107116 }
108117 }
109118
110- const int smem = WARP_SIZE *sizeof (float );
119+ const int smem = warp_size *sizeof (float );
111120 const dim3 block_nums (nrows, 1 , nchannels_y);
112121 const dim3 block_dims (block_size_best, 1 , 1 );
113122 switch (block_size_best) {
0 commit comments