 #include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
+#include "dispatch_utils.h"
 
 #include "ggml-common.h"
 #include "vecdotq.cuh"
 #include "mmq.cuh"
 
 // Q8 gemv
-static __global__ void quantize_q8_1(const half* __restrict__ x,
+template <typename scalar_t>
+static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
                                       void* __restrict__ vy, const int kx,
                                       const int kx_padded) {
   const int ix = blockDim.x * blockIdx.x + threadIdx.x;
@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
   const int ib = i_padded / QK8_1;   // block index
   const int iqs = i_padded % QK8_1;  // quant index
 
-  const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
+  const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
   float amax = fabsf(xi);
   float sum = xi;
 
@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
   y[ib].ds.y = __float2half(sum);
 }
 
-static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
+template <typename scalar_t>
+static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
                                     const int ky, cudaStream_t stream) {
   const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
   const int block_num_x =
       (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
   const dim3 num_blocks(block_num_x, ky, 1);
   const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+  quantize_q8_1<scalar_t>
+      <<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
@@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
   at::Tensor Y = torch::empty({1, row}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
   at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
-                         stream);
-  switch (type) {
-    case 2:
-      mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 3:
-      mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 6:
-      mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 7:
-      mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 8:
-      mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 10:
-      mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 11:
-      mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 12:
-      mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 13:
-      mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 14:
-      mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 16:
-      mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 17:
-      mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 18:
-      mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 19:
-      mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 20:
-      mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 21:
-      mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 22:
-      mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 23:
-      mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 29:
-      mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-  }
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
+    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
+                                     (void*)quant_X.data_ptr(), col, 1, stream);
+    switch (type) {
+      case 2:
+        mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 3:
+        mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 6:
+        mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 7:
+        mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 8:
+        mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 10:
+        mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 11:
+        mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 12:
+        mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 13:
+        mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 14:
+        mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 16:
+        mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 17:
+        mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 18:
+        mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 19:
+        mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 20:
+        mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 21:
+        mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 22:
+        mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 23:
+        mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 29:
+        mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+    }
+  });
   return Y;
 }
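
Note on the dispatch change: VLLM_DISPATCH_FLOATING_TYPES comes from the newly included dispatch_utils.h and follows PyTorch's AT_DISPATCH pattern. It switches on the runtime scalar type of X and exposes the matching C++ type as scalar_t inside the lambda, which is what lets the templated quantize and mul-mat kernels handle activations in more than one floating-point dtype (e.g., half and bfloat16). A rough sketch of such a macro, for orientation only; the real definition, and the exact set of dispatched types, live in dispatch_utils.h:

#include <ATen/Dispatch.h>

// Illustrative AT_DISPATCH-style macro; the actual one in dispatch_utils.h
// may cover a different set of scalar types.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)        \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
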
 
@@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
   at::Tensor Y = torch::empty({batch, row}, options);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
   at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
-                         batch, stream);
-
-  switch (type) {
-    case 2:
-      ggml_mul_mat_q4_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 3:
-      ggml_mul_mat_q4_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 6:
-      ggml_mul_mat_q5_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 7:
-      ggml_mul_mat_q5_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 8:
-      ggml_mul_mat_q8_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 10:
-      ggml_mul_mat_q2_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 11:
-      ggml_mul_mat_q3_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 12:
-      ggml_mul_mat_q4_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 13:
-      ggml_mul_mat_q5_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 14:
-      ggml_mul_mat_q6_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-  }
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
+                           col, batch, stream);
+
+    switch (type) {
+      case 2:
+        ggml_mul_mat_q4_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 3:
+        ggml_mul_mat_q4_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 6:
+        ggml_mul_mat_q5_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 7:
+        ggml_mul_mat_q5_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 8:
+        ggml_mul_mat_q8_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 10:
+        ggml_mul_mat_q2_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 11:
+        ggml_mul_mat_q3_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 12:
+        ggml_mul_mat_q4_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 13:
+        ggml_mul_mat_q5_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 14:
+        ggml_mul_mat_q6_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+    }
+  });
   return Y;
 }
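
Note on the quant_X sizing: the {rows, padded / 32 * 9} int32 allocation follows from the Q8_1 block layout these kernels consume (declared in ggml-common.h). Each block covers QK8_1 = 32 activations and stores 32 int8 quants plus a half2 holding the scale and the block sum, i.e. 36 bytes = 9 int32 words per 32 padded input values. A minimal sketch of that layout, with illustrative names (the real struct lives in ggml-common.h):

#include <cstdint>
#include <cuda_fp16.h>

// Sketch of one Q8_1 block; consistent with the y[ib].ds.y = __float2half(sum)
// write in quantize_q8_1 above.
struct block_q8_1_sketch {
  half2 ds;       // ds.x = scale d, ds.y = sum of the 32 quantized values
  int8_t qs[32];  // one int8 quant per activation
};
// 4 + 32 = 36 bytes per 32 values, i.e. 9 int32 words, hence
// quant_X = torch::empty({rows, padded / 32 * 9}, torch::kInt32).
static_assert(sizeof(block_q8_1_sketch) == 9 * sizeof(int32_t),
              "one Q8_1 block should occupy nine int32 words");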