@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
141141__launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142static __global__ void mul_mat_vec_q(
143143 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144- const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
144+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145+ const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146+ const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
147148
148149 constexpr int qk = ggml_cuda_type_traits<type>::qk;
149150 constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
161162 constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
162163
163164 // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
164- const int channel_dst = blockIdx .y ;
165- const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166- const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167- const int sample_dst = blockIdx .z ;
168- const int sample_x = sample_dst / sample_ratio;
169- const int sample_y = sample_dst;
165+ const uint32_t channel_dst = blockIdx .y ;
166+ const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv ( channel_dst, channel_ratio) ;
167+ const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo ( channel_dst, nchannels_y) : channel_dst;
168+ const uint32_t sample_dst = blockIdx .z ;
169+ const uint32_t sample_x = fastdiv ( sample_dst, sample_ratio) ;
170+ const uint32_t sample_y = sample_dst;
170171
171172 // partial sum for each thread
172173 float tmp[ncols_dst][rows_per_cuda_block] = {{0 .0f }};
@@ -247,95 +248,80 @@ static void mul_mat_vec_q_switch_ncols_dst(
247248 GGML_ASSERT (ncols_x % ggml_blck_size (type) == 0 );
248249 GGML_ASSERT (ncols_dst <= MMVQ_MAX_BATCH_SIZE);
249250
250- const int channel_ratio = nchannels_dst / nchannels_x;
251- const int sample_ratio = nsamples_dst / nsamples_x;
251+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values (nchannels_y) : make_uint3 (0 , 0 , 0 );
252+ const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
253+ const uint3 sample_ratio_fd = init_fastdiv_values (nsamples_dst / nsamples_x);
252254
253255 const int device = ggml_cuda_get_device ();
254256 const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
255257 const mmvq_parameter_table_id table_id = get_device_table_id (ggml_cuda_info ().devices [device].cc );
256258
257259 GGML_ASSERT (!ids || ncols_dst == 1 );
258260 switch (ncols_dst) {
259- case 1 :
260- {
261+ case 1 : {
261262 constexpr int c_ncols_dst = 1 ;
262263 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263264 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
264- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267- break ;
268- }
269- case 2 :
270- {
265+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268+ } break ;
269+ case 2 : {
271270 constexpr int c_ncols_dst = 2 ;
272271 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
273272 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
274- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277- break ;
278- }
279- case 3 :
280- {
273+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276+ } break ;
277+ case 3 : {
281278 constexpr int c_ncols_dst = 3 ;
282279 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
283280 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
284- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287- break ;
288- }
289- case 4 :
290- {
281+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284+ } break ;
285+ case 4 : {
291286 constexpr int c_ncols_dst = 4 ;
292287 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
293288 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
294- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297- break ;
298- }
299- case 5 :
300- {
289+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292+ } break ;
293+ case 5 : {
301294 constexpr int c_ncols_dst = 5 ;
302295 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303296 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
304- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307- break ;
308- }
309- case 6 :
310- {
297+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300+ } break ;
301+ case 6 : {
311302 constexpr int c_ncols_dst = 6 ;
312303 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
313304 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
314- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317- break ;
318- }
319- case 7 :
320- {
305+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308+ } break ;
309+ case 7 : {
321310 constexpr int c_ncols_dst = 7 ;
322311 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
323312 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
324- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327- break ;
328- }
329- case 8 :
330- {
313+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316+ } break ;
317+ case 8 : {
331318 constexpr int c_ncols_dst = 8 ;
332319 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
333320 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
334- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337- break ;
338- }
321+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324+ } break ;
339325 default :
340326 GGML_ABORT (" fatal error" );
341327 break ;
0 commit comments