@@ -47,36 +47,93 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
4747 1 ;
4848}
4949
// Selects the kernel-tuning table for the device this translation unit is
// being compiled for. Table 1 is used on AMD RDNA2/RDNA3 GPUs, table 0 on
// everything else (NVIDIA and older AMD architectures).
static constexpr __device__ int get_device_table() {
#if defined(RDNA2) || defined(RDNA3)
    return 1;
#else
    return 0;
#endif // defined(RDNA2) || defined(RDNA3)
}
58+
// Host-side counterpart of the device get_device_table(): maps a runtime
// compute-capability value to the tuning-table id used by calc_nwarps() and
// calc_rows_per_block(). Returns 1 for RDNA2/RDNA3, 0 otherwise.
static __host__ int get_device_table(int cc) {
    const bool is_rdna2_or_3 = GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
    return is_rdna2_or_3 ? 1 : 0;
}
67+
// Number of warps per block for a given batch size (ncols_y) and device
// tuning table. Must agree with the pre-refactor selection this replaces:
// the removed host code used nwarps = ncols_y <= 4 ? 4 : 2 and the removed
// __launch_bounds__ used (ncols_y <= 4 ? 4 : 2)*WARP_SIZE — i.e. 4 warps for
// ncols_y 1..4 and 2 warps for 5..8 (the values were accidentally swapped in
// the new table). Table 1 (RDNA2/RDNA3) always uses a single warp.
static constexpr int calc_nwarps(int ncols_y, int table_id) {
    if (table_id == 0) {
        switch (ncols_y) {
            case 1:
            case 2:
            case 3:
            case 4:
                return 4;
            case 5:
            case 6:
            case 7:
            case 8:
                return 2;
            default:
                return 1;
        }
    }
    return 1;
}
90+
// Number of output rows each CUDA block computes for a given batch size
// (ncols_y) and device tuning table: on the generic table (0) a single row
// for ncols_y == 1 and two rows for the batched cases 2..8; one row for
// anything else and for table 1 (RDNA2/RDNA3).
static constexpr int calc_rows_per_block(int ncols_y, int table_id) {
    if (table_id != 0) {
        return 1;
    }
    return (ncols_y >= 2 && ncols_y <= 8) ? 2 : 1;
}
112+
50113template <ggml_type type, int ncols_y>
51- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
52114// tell the compiler to use as many registers as it wants, see nwarps definition below
53- __launch_bounds__ ((ncols_y <= 4 ? 4 : 2 )*WARP_SIZE, 1)
54- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
115+ __launch_bounds__ (calc_nwarps(ncols_y, get_device_table())*ggml_cuda_get_physical_warp_size(), 4)
55116static __global__ void mul_mat_vec_q(
56117 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
57118 const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
58119
59120 constexpr int qk = ggml_cuda_type_traits<type>::qk;
60121 constexpr int qi = ggml_cuda_type_traits<type>::qi;
61122 constexpr int vdr = get_vdr_mmvq (type);
123+ constexpr int table_id = get_device_table ();
124+ constexpr int nwarps = calc_nwarps (ncols_y, table_id);
125+ constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_y, table_id);
126+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
62127
63128 constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda (type);
64129
65- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
66- constexpr int nwarps = 1 ;
67- constexpr int rows_per_cuda_block = 1 ;
68- #else
69- constexpr int nwarps = ncols_y <= 4 ? 4 : 2 ;
70- constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2 ;
71- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
72-
73- const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
130+ const int tid = warp_size*threadIdx .y + threadIdx .x ;
74131 const int row0 = rows_per_cuda_block*blockIdx .x ;
75132 const int blocks_per_row_x = ncols_x / qk;
76133 const int blocks_per_col_y = nrows_y / QK8_1;
77- constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
134+ constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
78135
79- // partial sum for each thread
136+ // partial sum for each thread
80137 float tmp[ncols_y][rows_per_cuda_block] = {0 .0f };
81138
82139 const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +153,7 @@ static __global__ void mul_mat_vec_q(
96153 }
97154 }
98155
99- __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y][rows_per_cuda_block][WARP_SIZE ];
156+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y][rows_per_cuda_block][warp_size ];
100157 if (threadIdx .y > 0 ) {
101158#pragma unroll
102159 for (int j = 0 ; j < ncols_y; ++j) {
@@ -120,7 +177,7 @@ static __global__ void mul_mat_vec_q(
120177 for (int l = 0 ; l < nwarps-1 ; ++l) {
121178 tmp[j][i] += tmp_shared[l][j][i][threadIdx .x ];
122179 }
123- tmp[j][i] = warp_reduce_sum (tmp[j][i]);
180+ tmp[j][i] = warp_reduce_sum<warp_size> (tmp[j][i]);
124181 }
125182
126183 if (threadIdx .x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx .x < nrows_dst)) {
@@ -129,73 +186,85 @@ static __global__ void mul_mat_vec_q(
129186 }
130187}
131188
// Computes the launch configuration for mul_mat_vec_q: the grid covers
// nrows_x output rows in chunks of rows_per_block rows per block, and each
// block runs calc_nwarps() warps of warp_size threads.
// Returns {grid dims, block dims}.
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id) {
    // Hoisted so the rows-per-block table lookup is done once, not twice.
    const int     rows_per_block = calc_rows_per_block(ncols_y, table_id);
    const int64_t nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block; // ceil-div
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
    return {block_nums, block_dims};
}
196+
// Launches one mul_mat_vec_q instantiation for a compile-time column count.
// Factored out so the ncols_y dispatch below is one line per case instead of
// eight copies of the same launch boilerplate.
template <ggml_type type, int c_ncols_y>
static void launch_mul_mat_vec_q(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
    const int warp_size, const int table_id, cudaStream_t stream) {
    const std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
    mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}

// Host entry point: validates sizes, reads the current device's warp size and
// tuning table, and dispatches to the kernel instantiation matching ncols_y
// (1..8). Aborts on an unsupported ncols_y.
template <ggml_type type>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int device;
    CUDA_CHECK(cudaGetDevice(&device));
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
    const int table_id  = get_device_table(ggml_cuda_info().devices[device].cc);

    switch (ncols_y) {
        case 1:
            launch_mul_mat_vec_q<type, 1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 2:
            launch_mul_mat_vec_q<type, 2>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 3:
            launch_mul_mat_vec_q<type, 3>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 4:
            launch_mul_mat_vec_q<type, 4>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 5:
            launch_mul_mat_vec_q<type, 5>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 6:
            launch_mul_mat_vec_q<type, 6>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 7:
            launch_mul_mat_vec_q<type, 7>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        case 8:
            launch_mul_mat_vec_q<type, 8>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, warp_size, table_id, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
0 commit comments