@@ -47,36 +47,108 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
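+// Tuning-table id for the compile target: 2 = RDNA2/RDNA3, 1 = GCN/CDNA, 0 = NVIDIA and everything else.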
+static constexpr __device__ int get_device_table_id()
+{
+#if defined(RDNA2) || defined(RDNA3)
+    return 2;
+#elif defined(GCN) || defined(CDNA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
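+// Host-side overload: maps a device's compute capability to the same table id as the device-side version.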
+static __host__ int get_device_table_id(int cc)
+{
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return 2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 1;
+    }
+    return 0;
+}
+
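+// Warps per block for a given batch size (ncols_y); table 2 (RDNA2/RDNA3) falls through to a single warp.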
+static constexpr int calc_nwarps(int ncols_y, int table_id)
+{
+    if (table_id == 0) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
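+// Output rows computed per block; tables 0 and 1 share the same mapping, table 2 uses one row per block.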
+static constexpr int calc_rows_per_block(int ncols_y, int table_id)
+{
+    if (table_id == 0 || table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
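+    // Launch configuration resolved at compile time from the tuning tables above.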
+    constexpr int table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int tid = warp_size*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +168,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
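+    // Warps other than the first stage their partial sums in shared memory for the final reduction.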
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +192,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,73 +201,85 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
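+// Grid and block dimensions derived from the same tuning tables the kernel's launch bounds use.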
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id)
+{
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
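+    // Queried below: the active device and its physical warp size.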
+    int device;
+    int warp_size;
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch (ncols_y) {
-            case 1:
-                nwarps              = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps              = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps              = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+    int table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
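+    // Map the runtime batch size to a compile-time template argument, one case per supported ncols_y.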
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;