@@ -47,36 +47,110 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }

+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int qi = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int tid = warp_size*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

-    // partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }

-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }

         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
     }
 }

+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch (ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
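
For reference, the selection logic introduced above can be exercised on the host alone. The sketch below is not part of the diff: it restates the calc_nwarps and calc_rows_per_block tables as plain C++ (no CUDA headers, plain functions instead of the __host__ __device__ helpers) and prints the grid/block dimensions that calc_launch_params would derive from them. The warp size of 64 (a GCN/CDNA wavefront), the 4096-row example, and the choice of the GCN table are arbitrary assumptions for illustration.

#include <cstdio>

enum mmvq_parameter_table_id {
    MMVQ_PARAMETERS_GENERIC = 0,
    MMVQ_PARAMETERS_GCN,
    MMVQ_PARAMETERS_RDNA2
};

// Restates the nwarps table from the diff: generic path uses 4 warps for
// ncols_y <= 4 and 2 for 5..8; GCN/CDNA uses 2 for ncols_y <= 4; RDNA2/RDNA3 uses 1.
static constexpr int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC) {
        return ncols_y <= 4 ? 4 : (ncols_y <= 8 ? 2 : 1);
    }
    if (table_id == MMVQ_PARAMETERS_GCN) {
        return ncols_y <= 4 ? 2 : 1;
    }
    return 1;
}

// Restates the rows-per-block table: generic and GCN handle 2 rows per block
// for ncols_y in [2, 8], otherwise 1; RDNA2/RDNA3 always handles 1 row.
static constexpr int calc_rows_per_block(int ncols_y, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
        return (ncols_y >= 2 && ncols_y <= 8) ? 2 : 1;
    }
    return 1;
}

int main() {
    const int nrows_x   = 4096;   // example matrix row count (assumption)
    const int warp_size = 64;     // e.g. GCN/CDNA wavefront size (assumption)
    const mmvq_parameter_table_id table_id = MMVQ_PARAMETERS_GCN;

    for (int ncols_y = 1; ncols_y <= 8; ++ncols_y) {
        const int rows_per_block = calc_rows_per_block(ncols_y, table_id);
        const int nwarps         = calc_nwarps(ncols_y, table_id);
        const int nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block;
        // grid = (nblocks, 1, 1), block = (warp_size, nwarps, 1), as in calc_launch_params
        std::printf("ncols_y=%d -> grid.x=%d block=(%d,%d,1)\n", ncols_y, nblocks, warp_size, nwarps);
    }
    return 0;
}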