11#include " ggml.h"
22#include " mmf.cuh"
3+ #include " mmid.cuh"
4+
35
46void ggml_cuda_mul_mat_f (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
57 GGML_ASSERT ( src1->type == GGML_TYPE_F32);
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
3739 const int64_t ids_s0 = ids ? ids->nb [0 ] / ggml_type_size (ids->type ) : 0 ;
3840 const int64_t ids_s1 = ids ? ids->nb [1 ] / ggml_type_size (ids->type ) : 0 ;
3941
// Optional compact-id metadata for the MUL_MAT_ID path. ids_info is value-initialized and
// ids_info_ptr starts as nullptr; it is only pointed at ids_info when the
// `ids && ncols_dst > 16` branch further down fills the struct in.
42+ mmf_ids_data ids_info{};
43+ mmf_ids_data * ids_info_ptr = nullptr ;
// Pool-backed int32 buffers. They are declared here without a size and are only allocated
// (via .alloc(ctx.pool(), ...)) inside that same branch.
44+ ggml_cuda_pool_alloc<int32_t > ids_src_compact_dev;
45+ ggml_cuda_pool_alloc<int32_t > ids_dst_compact_dev;
46+ ggml_cuda_pool_alloc<int32_t > expert_bounds_dev;
47+
4048 // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
4149 const int64_t ncols_dst = ids ? ne2 : ne1;
4250 const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -54,30 +62,57 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
5462 nchannels_y = ids->ne [0 ];
5563 }
5664
// MUL_MAT_ID only (ids != nullptr) and more than 16 dst columns: build the compact id tables
// with ggml_cuda_launch_mm_ids_helper and publish them through ids_info / ids_info_ptr.
// In every other case ids_info_ptr is left untouched (nullptr from its declaration).
65+ if (ids && ncols_dst > 16 ) {
66+ const int64_t n_expert_used = ids->ne [0 ];
67+ const int64_t n_experts = ne02;
68+ const int64_t n_tokens = ne12;
// One entry per (token, used expert) pair.
69+ const int64_t ne_get_rows = n_tokens * n_expert_used;
70+
// Both compact id buffers hold ne_get_rows int32 values; expert_bounds holds n_experts + 1.
// NOTE(review): the extra slot is presumably an end sentinel for per-expert ranges -- confirm in mmid.cuh.
71+ ids_src_compact_dev.alloc (ctx.pool (), ne_get_rows);
72+ ids_dst_compact_dev.alloc (ctx.pool (), ne_get_rows);
73+ expert_bounds_dev.alloc (ctx.pool (), n_experts + 1 );
74+
// si1: ids_s1 (ids->nb[1] expressed in elements of ids->type) narrowed to int.
75+ const int si1 = static_cast <int >(ids_s1);
// sis1: ratio of the src1 byte strides nb[2] / nb[1] (integer division), narrowed to int.
76+ const int sis1 = static_cast <int >(src1->nb [2 ] / src1->nb [1 ]);
77+
// Guards against a zero or negative ratio before it is handed to the helper and stored in ids_info.
78+ GGML_ASSERT (sis1 > 0 );
79+
// The helper receives the raw ids plus the three output buffers; all 64-bit sizes are
// narrowed to int for the call. Work is issued on ctx.stream().
80+ ggml_cuda_launch_mm_ids_helper (ids_d, ids_src_compact_dev.get (), ids_dst_compact_dev.get (), expert_bounds_dev.get (),
81+ static_cast <int >(n_experts), static_cast <int >(n_tokens), static_cast <int >(n_expert_used), static_cast <int >(ne11), si1, sis1, ctx.stream ());
// Pick up (and clear) any error recorded by the helper call above.
82+ CUDA_CHECK (cudaGetLastError ());
83+
// Publish the buffer pointers and parameters. ids_info_ptr now points at the local ids_info,
// which is what the mul_mat_f_switch_cols_per_block calls below receive as their last argument.
84+ ids_info.ids_src_compact = ids_src_compact_dev.get ();
85+ ids_info.ids_dst_compact = ids_dst_compact_dev.get ();
86+ ids_info.expert_bounds_dev = expert_bounds_dev.get ();
87+ ids_info.n_experts = static_cast <int >(n_experts);
88+ ids_info.sis1 = sis1;
89+ ids_info_ptr = &ids_info;
90+ }
91+
5792 switch (src0->type ) {
5893 case GGML_TYPE_F32: {
5994 const float * src0_d = (const float *) src0->data ;
6095 constexpr int vals_per_T = 1 ;
6196 mul_mat_f_switch_cols_per_block (
6297 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
6398 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
64- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
99+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
65100 } break ;
66101 case GGML_TYPE_F16: {
67102 const half2 * src0_d = (const half2 *) src0->data ;
68103 constexpr int vals_per_T = 2 ;
69104 mul_mat_f_switch_cols_per_block (
70105 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
71106 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
72- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
107+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
73108 } break ;
74109 case GGML_TYPE_BF16: {
75110 const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data ;
76111 constexpr int vals_per_T = 2 ;
77112 mul_mat_f_switch_cols_per_block (
78113 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
79114 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
80- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream ());
115+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream (), ids_info_ptr );
81116 } break ;
82117 default :
83118 GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
98133 }
99134
100135 if (mul_mat_id) {
101- if (type == GGML_TYPE_F32 && src1_ncols > 32 ) {
136+ if (src0_ne[ 1 ] <= 1024 && src1_ncols > 512 ) {
102137 return false ;
103- }
104- if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64 ) {
138+ } else if (src0_ne[1 ] > 1024 && src1_ncols > 128 ) {
105139 return false ;
106140 }
107141 } else {
0 commit comments