11#include " ggml.h"
22#include " mmf.cuh"
3- #include " mmid.cuh"
4-
53
64void ggml_cuda_mul_mat_f (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
75 GGML_ASSERT ( src1->type == GGML_TYPE_F32);
@@ -39,12 +37,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
3937 const int64_t ids_s0 = ids ? ids->nb [0 ] / ggml_type_size (ids->type ) : 0 ;
4038 const int64_t ids_s1 = ids ? ids->nb [1 ] / ggml_type_size (ids->type ) : 0 ;
4139
42- mmf_ids_data ids_info{};
43- mmf_ids_data * ids_info_ptr = nullptr ;
44- ggml_cuda_pool_alloc<int32_t > ids_src_compact_dev;
45- ggml_cuda_pool_alloc<int32_t > ids_dst_compact_dev;
46- ggml_cuda_pool_alloc<int32_t > expert_bounds_dev;
47-
4840 // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
4941 const int64_t ncols_dst = ids ? ne2 : ne1;
5042 const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -62,57 +54,30 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
6254 nchannels_y = ids->ne [0 ];
6355 }
6456
65- if (ids && ncols_dst > 16 ) {
66- const int64_t n_expert_used = ids->ne [0 ];
67- const int64_t n_experts = ne02;
68- const int64_t n_tokens = ne12;
69- const int64_t ne_get_rows = n_tokens * n_expert_used;
70-
71- ids_src_compact_dev.alloc (ctx.pool (), ne_get_rows);
72- ids_dst_compact_dev.alloc (ctx.pool (), ne_get_rows);
73- expert_bounds_dev.alloc (ctx.pool (), n_experts + 1 );
74-
75- const int si1 = static_cast <int >(ids_s1);
76- const int sis1 = static_cast <int >(src1->nb [2 ] / src1->nb [1 ]);
77-
78- GGML_ASSERT (sis1 > 0 );
79-
80- ggml_cuda_launch_mm_ids_helper (ids_d, ids_src_compact_dev.get (), ids_dst_compact_dev.get (), expert_bounds_dev.get (),
81- static_cast <int >(n_experts), static_cast <int >(n_tokens), static_cast <int >(n_expert_used), static_cast <int >(ne11), si1, sis1, ctx.stream ());
82- CUDA_CHECK (cudaGetLastError ());
83-
84- ids_info.ids_src_compact = ids_src_compact_dev.get ();
85- ids_info.ids_dst_compact = ids_dst_compact_dev.get ();
86- ids_info.expert_bounds_dev = expert_bounds_dev.get ();
87- ids_info.n_experts = static_cast <int >(n_experts);
88- ids_info.sis1 = sis1;
89- ids_info_ptr = &ids_info;
90- }
91-
9257 switch (src0->type ) {
9358 case GGML_TYPE_F32: {
9459 const float * src0_d = (const float *) src0->data ;
9560 constexpr int vals_per_T = 1 ;
9661 mul_mat_f_switch_cols_per_block (
9762 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
9863 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
99- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
64+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
10065 } break ;
10166 case GGML_TYPE_F16: {
10267 const half2 * src0_d = (const half2 *) src0->data ;
10368 constexpr int vals_per_T = 2 ;
10469 mul_mat_f_switch_cols_per_block (
10570 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
10671 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
107- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
72+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
10873 } break ;
10974 case GGML_TYPE_BF16: {
11075 const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data ;
11176 constexpr int vals_per_T = 2 ;
11277 mul_mat_f_switch_cols_per_block (
11378 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
11479 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
115- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
80+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
11681 } break ;
11782 default :
11883 GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
@@ -133,9 +98,10 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
13398 }
13499
135100 if (mul_mat_id) {
136- if (src0_ne[ 1 ] <= 1024 && src1_ncols > 512 ) {
101+ if (type == GGML_TYPE_F32 && src1_ncols > 32 ) {
137102 return false ;
138- } else if (src0_ne[1 ] > 1024 && src1_ncols > 128 ) {
103+ }
104+ if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64 ) {
139105 return false ;
140106 }
141107 } else {
0 commit comments