1+ #include " ggml.h"
12#include " common.cuh"
23#include " mmv.cuh"
34
45template <typename T, typename type_acc, int block_size>
56static __global__ void mul_mat_vec (
67 const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
7- const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
8+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
9+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
810 const int64_t row = blockIdx .x ;
9- const int64_t channel = blockIdx .z ;
11+ const int64_t channel = blockIdx .y ;
12+ const int64_t sample = blockIdx .z ;
1013 const int tid = threadIdx .x ;
1114 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
1215
13- x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
14- y += channel *stride_channel_y;
15- dst += channel *stride_channel_dst;
16+ x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
17+ y += sample *stride_sample_y + channel *stride_channel_y;
18+ dst += sample *stride_sample_dst + channel *stride_channel_dst;
1619
1720 const float2 * y2 = (const float2 *) y;
1821
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
9194static void launch_mul_mat_vec_cuda (
9295 const T * x, const float * y, float * dst,
9396 const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
94- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
97+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
98+ const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
9599 cudaStream_t stream) {
96100 GGML_ASSERT (ncols % 2 == 0 );
97101 GGML_ASSERT (stride_row % 2 == 0 );
98102 GGML_ASSERT (nchannels_y % nchannels_x == 0 );
103+ GGML_ASSERT (nsamples_y % nsamples_x == 0 );
99104 const int64_t channel_ratio = nchannels_y / nchannels_x;
105+ const int64_t sample_ratio = nsamples_y / nsamples_x;
100106 int device;
101107 int warp_size;
102108
@@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
118124 }
119125
120126 const int smem = warp_size*sizeof (float );
121- const dim3 block_nums (nrows, 1 , nchannels_y );
127+ const dim3 block_nums (nrows, nchannels_y, nsamples_y );
122128 const dim3 block_dims (block_size_best, 1 , 1 );
123129 switch (block_size_best) {
124130 case 32 : {
125131 mul_mat_vec<T, type_acc, 32 ><<<block_nums, block_dims, smem, stream>>>
126- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
132+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
133+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
127134 } break ;
128135 case 64 : {
129136 mul_mat_vec<T, type_acc, 64 ><<<block_nums, block_dims, smem, stream>>>
130- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
137+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
138+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
131139 } break ;
132140 case 96 : {
133141 mul_mat_vec<T, type_acc, 96 ><<<block_nums, block_dims, smem, stream>>>
134- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
142+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
143+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
135144 } break ;
136145 case 128 : {
137146 mul_mat_vec<T, type_acc, 128 ><<<block_nums, block_dims, smem, stream>>>
138- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
147+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
148+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
139149 } break ;
140150 case 160 : {
141151 mul_mat_vec<T, type_acc, 160 ><<<block_nums, block_dims, smem, stream>>>
142- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
152+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
153+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
143154 } break ;
144155 case 192 : {
145156 mul_mat_vec<T, type_acc, 192 ><<<block_nums, block_dims, smem, stream>>>
146- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
157+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
158+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
147159 } break ;
148160 case 224 : {
149161 mul_mat_vec<T, type_acc, 224 ><<<block_nums, block_dims, smem, stream>>>
150- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
162+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
163+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
151164 } break ;
152165 case 256 : {
153166 mul_mat_vec<T, type_acc, 256 ><<<block_nums, block_dims, smem, stream>>>
154- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
167+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
168+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
155169 } break ;
156170 default : {
157171 GGML_ABORT (" fatal error" );
@@ -163,16 +177,19 @@ template<typename T>
163177static void mul_mat_vec_cuda (
164178 const T * x, const float * y, float * dst,
165179 const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
166- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
180+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
181+ const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
167182 enum ggml_prec prec, cudaStream_t stream) {
168183 switch (prec) {
169184 case GGML_PREC_DEFAULT: {
170- launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
171- stride_channel_x, stride_channel_y, stride_channel_dst, stream);
185+ launch_mul_mat_vec_cuda<T, half>
186+ (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
187+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
172188 } break ;
173189 case GGML_PREC_F32: {
174- launch_mul_mat_vec_cuda<T, float >(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
175- stride_channel_x, stride_channel_y, stride_channel_dst, stream);
190+ launch_mul_mat_vec_cuda<T, float >
191+ (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
192+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
176193 } break ;
177194 }
178195}
@@ -181,40 +198,42 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
181198 GGML_ASSERT (src1->type == GGML_TYPE_F32);
182199 GGML_ASSERT (dst->type == GGML_TYPE_F32);
183200
184- const int64_t ne00 = src0->ne [0 ];
185- const int64_t ne01 = src0->ne [1 ];
201+ GGML_TENSOR_BINARY_OP_LOCALS;
202+
203+ const size_t ts_src0 = ggml_type_size (src0->type );
204+ const size_t ts_src1 = ggml_type_size (src1->type );
205+ const size_t ts_dst = ggml_type_size (dst->type );
206+
207+ GGML_ASSERT (ne11 == 1 );
208+ GGML_ASSERT (ne12 == ne2);
209+ GGML_ASSERT (ne13 == ne3);
186210
187- GGML_ASSERT (src1->ne [1 ] == 1 );
211+ GGML_ASSERT (nb00 == ts_src0);
212+ GGML_ASSERT (nb10 == ts_src1);
213+ GGML_ASSERT (nb0 == ts_dst);
188214
189215 const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
190216 const enum ggml_prec prec = fast_fp16_available (cc) ? ggml_prec (dst->op_params [0 ]) : GGML_PREC_F32;
191217
192218 const float * src1_d = (const float *) src1->data ;
193219 float * dst_d = (float *) dst->data ;
194220
195- const int64_t ne02 = src0->ne [2 ];
196- const int64_t ne12 = src1->ne [2 ];
197- GGML_ASSERT (dst->ne [2 ] == ne12);
198-
199- GGML_ASSERT (src0->ne [3 ] == 1 );
200- GGML_ASSERT (src1->ne [3 ] == 1 );
201- GGML_ASSERT ( dst->ne [3 ] == 1 );
202-
203- const int64_t stride_row = src0->nb [1 ] / ggml_type_size (src0->type );
204- const int64_t channel_stride_x = src0->nb [2 ] / ggml_type_size (src0->type );
205- const int64_t channel_stride_y = src1->nb [2 ] / ggml_type_size (src1->type );
206- const int64_t channel_stride_dst = dst->nb [2 ] / ggml_type_size ( dst->type );
221+ const int64_t s01 = src0->nb [1 ] / ts_src0;
222+ const int64_t s02 = src0->nb [2 ] / ts_src0;
223+ const int64_t s12 = src1->nb [2 ] / ts_src1;
224+ const int64_t s2 = dst->nb [2 ] / ts_dst;
225+ const int64_t s03 = src0->nb [3 ] / ts_src0;
226+ const int64_t s13 = src1->nb [3 ] / ts_src1;
227+ const int64_t s3 = dst->nb [3 ] / ts_dst;
207228
208229 switch (src0->type ) {
209230 case GGML_TYPE_F16: {
210231 const half * src0_d = (const half *) src0->data ;
211- mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
212- channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
232+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream ());
213233 } break ;
214234 case GGML_TYPE_BF16: {
215235 const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data ;
216- mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
217- channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
236+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream ());
218237 } break ;
219238 default :
220239 GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
243262 const int64_t stride_row = ne00;
244263 const int64_t nchannels_x = 1 ;
245264 const int64_t nchannels_y = 1 ;
246- const int64_t channel_stride_x = 0 ;
247- const int64_t channel_stride_y = 0 ;
248- const int64_t channel_stride_dst = 0 ;
265+ const int64_t stride_channel_x = 0 ;
266+ const int64_t stride_channel_y = 0 ;
267+ const int64_t stride_channel_dst = 0 ;
268+ const int64_t nsamples_x = 1 ;
269+ const int64_t nsamples_y = 1 ;
270+ const int64_t stride_sample_x = 0 ;
271+ const int64_t stride_sample_y = 0 ;
272+ const int64_t stride_sample_dst = 0 ;
249273
250274 switch (src0->type ) {
251275 case GGML_TYPE_F16: {
252276 const half * src0_d = (const half *) src0_dd_i;
253277 mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
254- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
278+ nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
279+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
255280 } break ;
256281 case GGML_TYPE_BF16: {
257282 const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
258283 mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
259- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
284+ nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
285+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
260286 } break ;
261287 default :
262288 GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
0 commit comments