@@ -344,29 +344,30 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
344344 const int64_t s3 = dst->nb [3 ] / ts_dst;
345345
346346 // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
347+ const int64_t ncols_dst = ids ? ne2 : ne1;
347348 const int64_t nchannels_y = ids ? ne11 : ne12;
348349 const int64_t nchannels_dst = ids ? ne1 : ne2;
349350 const int64_t stride_channel_dst = ids ? s1 : s2;
350351 const int64_t stride_channel_y = ids ? s11 : s12;
351352
352- GGML_ASSERT (!ids || ne1 == 1 );
353+ GGML_ASSERT (!ids || ncols_dst == 1 );
353354
354355 switch (src0->type ) {
355356 case GGML_TYPE_F32: {
356357 const float * src0_d = (const float *) src0->data ;
357- mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1 , s01, s11, s1,
358+ mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst , s01, s11, s1,
358359 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
359360 ne03, ne3, s03, s13, s3, prec, ctx.stream ());
360361 } break ;
361362 case GGML_TYPE_F16: {
362363 const half * src0_d = (const half *) src0->data ;
363- mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1 , s01, s11, s1,
364+ mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst , s01, s11, s1,
364365 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
365366 ne03, ne3, s03, s13, s3, prec, ctx.stream ());
366367 } break ;
367368 case GGML_TYPE_BF16: {
368369 const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data ;
369- mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1 , s01, s11, s1,
370+ mul_mat_vec_cuda (src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst , s01, s11, s1,
370371 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
371372 ne03, ne3, s03, s13, s3, prec, ctx.stream ());
372373 } break ;
0 commit comments