Skip to content

Commit 60ce04c

Browse files
fix mul_mat_id
1 parent 2d24a9c commit 60ce04c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ggml/src/ggml-cuda/mmv.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -344,29 +344,30 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
344344
const int64_t s3 = dst->nb[3] / ts_dst;
345345

346346
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
347+
const int64_t ncols_dst = ids ? ne2 : ne1;
347348
const int64_t nchannels_y = ids ? ne11 : ne12;
348349
const int64_t nchannels_dst = ids ? ne1 : ne2;
349350
const int64_t stride_channel_dst = ids ? s1 : s2;
350351
const int64_t stride_channel_y = ids ? s11 : s12;
351352

352-
GGML_ASSERT(!ids || ne1 == 1);
353+
GGML_ASSERT(!ids || ncols_dst == 1);
353354

354355
switch (src0->type) {
355356
case GGML_TYPE_F32: {
356357
const float * src0_d = (const float *) src0->data;
357-
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1, s01, s11, s1,
358+
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
358359
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
359360
ne03, ne3, s03, s13, s3, prec, ctx.stream());
360361
} break;
361362
case GGML_TYPE_F16: {
362363
const half * src0_d = (const half *) src0->data;
363-
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1, s01, s11, s1,
364+
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
364365
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
365366
ne03, ne3, s03, s13, s3, prec, ctx.stream());
366367
} break;
367368
case GGML_TYPE_BF16: {
368369
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
369-
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ne1, s01, s11, s1,
370+
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
370371
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
371372
ne03, ne3, s03, s13, s3, prec, ctx.stream());
372373
} break;

0 commit comments

Comments
 (0)