@@ -5310,45 +5310,65 @@ template <bool need_check> static __global__ void
53105310#endif // __CUDA_ARCH__ >= CC_VOLTA
53115311}
53125312
5313- template <int ncols_y_template, int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5313+ #define MMVQ_NWARPS_NVIDIA 4 // warps per block (blockDim.y) for mul_mat_vec_q when the device cc is below CC_OFFSET_AMD
5314+ #define MMVQ_NWARPS_AMD_RDNA2 1 // warps per block when cc >= CC_OFFSET_AMD and cc >= CC_RDNA2
5315+ #define MMVQ_NWARPS_AMD_OLD 4 // warps per block when cc >= CC_OFFSET_AMD but cc < CC_RDNA2; NOTE: the launcher's switch only instantiates nwarps 1 and 4, any other value hits GGML_ASSERT(false)
5316+
5317+ template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5318+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
5319+ __launch_bounds__ (nwarps*WARP_SIZE, 1 ) // tells the compiler to use as many registers as it wants
5320+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
53145321static __global__ void mul_mat_vec_q (
53155322 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
53165323 const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
53175324
53185325 const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
53195326
5320- const int row = blockIdx .x *blockDim .y + threadIdx .y ;
5321-
5322- if (row >= nrows_x) {
5323- return ;
5324- }
5327+ const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
5328+ const int row = blockIdx .x ;
53255329
53265330 const int blocks_per_row_x = ncols_x / qk;
53275331 const int blocks_per_col_y = nrows_y / QK8_1;
5328- const int blocks_per_warp = vdr * WARP_SIZE / qi;
5332+ const int blocks_per_iter = vdr * nwarps* WARP_SIZE / qi;
53295333
53305334// partial sum for each thread
53315335 float tmp[ncols_y_template != 0 ? ncols_y_template : 8 ] = {0 .0f };
53325336
53335337 const block_q_t * x = (const block_q_t *) vx;
53345338 const block_q8_1 * y = (const block_q8_1 *) vy;
53355339
5336- for (int i = threadIdx . x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp ) {
5340+ for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter ) {
53375341 const int ibx = row*blocks_per_row_x + i; // x block index
53385342
53395343 const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
53405344
5341- const int iqs = vdr * (threadIdx . x % (qi/vdr)); // x block quant index when casting the quants to int
5345+ const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
53425346
53435347#pragma unroll
53445348 for (int j = 0 ; j < ncols_y; ++j) {
53455349 tmp[j] += vec_dot_q_cuda (&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
53465350 }
53475351 }
53485352
5353+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y_template != 0 ? ncols_y_template : 8 ][WARP_SIZE];
5354+ if (threadIdx .y > 0 ) {
5355+ #pragma unroll
5356+ for (int j = 0 ; j < ncols_y; ++j) {
5357+ tmp_shared[threadIdx .y -1 ][j][threadIdx .x ] = tmp[j];
5358+ }
5359+ }
5360+ __syncthreads ();
5361+ if (threadIdx .y > 0 ) {
5362+ return ;
5363+ }
5364+
53495365 // sum up partial sums and write back result
53505366#pragma unroll
53515367 for (int j = 0 ; j < ncols_y; ++j) {
5368+ #pragma unroll
5369+ for (int i = 0 ; i < nwarps-1 ; ++i) {
5370+ tmp[j] += tmp_shared[i][j][threadIdx .x ];
5371+ }
53525372 tmp[j] = warp_reduce_sum (tmp[j]);
53535373
53545374 if (threadIdx .x == 0 ) {
@@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
68336853 GGML_ASSERT (ncols_x % qk == 0 );
68346854 GGML_ASSERT (ncols_y <= 4 );
68356855
6836- const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
6837- const dim3 block_nums (block_num_y, 1 , 1 );
6838- const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
6839- switch (ncols_y) {
6840- case 1 :
6841- mul_mat_vec_q<1 , qk, qi, block_q_t , vdr, vec_dot>
6842- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6843- break ;
6844- case 2 :
6845- mul_mat_vec_q<2 , qk, qi, block_q_t , vdr, vec_dot>
6846- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6847- break ;
6848- case 3 :
6849- mul_mat_vec_q<3 , qk, qi, block_q_t , vdr, vec_dot>
6850- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6851- break ;
6852- case 4 :
6853- mul_mat_vec_q<4 , qk, qi, block_q_t , vdr, vec_dot>
6854- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6855- break ;
6856- // case 5:
6857- // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6859- // break;
6860- // case 6:
6861- // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6863- // break;
6864- // case 7:
6865- // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6867- // break;
6868- // case 8:
6869- // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6871- // break;
6856+ int id;
6857+ CUDA_CHECK (cudaGetDevice (&id));
6858+
6859+ int nwarps;
6860+ if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
6861+ nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
6862+ } else {
6863+ nwarps = MMVQ_NWARPS_NVIDIA;
6864+ }
6865+
6866+ const dim3 block_nums (nrows_x, 1 , 1 );
6867+ const dim3 block_dims (WARP_SIZE, nwarps, 1 );
6868+
6869+ switch (nwarps) {
6870+ case 1 : switch (ncols_y) {
6871+ case 1 :
6872+ mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6873+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6874+ break ;
6875+ case 2 :
6876+ mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6877+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6878+ break ;
6879+ case 3 :
6880+ mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6881+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6882+ break ;
6883+ case 4 :
6884+ mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6885+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6886+ break ;
6887+ default :
6888+ GGML_ASSERT (false );
6889+ break ;
6890+ } break ;
6891+ case 4 : switch (ncols_y) {
6892+ case 1 :
6893+ mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6894+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6895+ break ;
6896+ case 2 :
6897+ mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6898+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6899+ break ;
6900+ case 3 :
6901+ mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6902+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6903+ break ;
6904+ case 4 :
6905+ mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6906+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6907+ break ;
6908+ default :
6909+ GGML_ASSERT (false );
6910+ break ;
6911+ } break ;
6912+
68726913 default :
68736914 GGML_ASSERT (false );
6874- // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68766915 break ;
68776916 }
68786917}
0 commit comments