Skip to content

Commit c710560

Browse files
committed
cleanup
1 parent 37f8f84 commit c710560

File tree

5 files changed

+22
-213
lines changed

5 files changed

+22
-213
lines changed

ggml/src/ggml-cuda/mmf.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#include "ggml.h"
22
#include "mmf.cuh"
3+
#include "mmid.cuh"
34

4-
void ggml_cuda_launch_mmq_ids_helper(
5-
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
6-
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
75

86
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
97
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -79,7 +77,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
7977

8078
GGML_ASSERT(sis1 > 0);
8179

82-
ggml_cuda_launch_mmq_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
80+
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
8381
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
8482
CUDA_CHECK(cudaGetLastError());
8583

@@ -88,7 +86,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
8886
ids_info.expert_bounds_dev = expert_bounds_dev.get();
8987
ids_info.n_experts = static_cast<int>(n_experts);
9088
ids_info.sis1 = sis1;
91-
ids_info.cols_per_tile = 16;
9289
ids_info_ptr = &ids_info;
9390
}
9491

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ struct mmf_ids_data {
1313
const int32_t * expert_bounds_dev = nullptr;
1414
int n_experts = 0;
1515
int sis1 = 0;
16-
int cols_per_tile = 0;
1716
};
1817

1918
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
@@ -502,8 +501,6 @@ static inline void mul_mat_f_switch_ids(
502501
if (max_tiles == 0) {
503502
return;
504503
}
505-
GGML_ASSERT(ids_data->cols_per_tile == 0 || ids_data->cols_per_tile == cols_per_block);
506-
507504
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
508505

509506
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
@@ -658,104 +655,86 @@ static void mul_mat_f_switch_cols_per_block(
658655

659656
GGML_ASSERT(ids || ncols_dst <= 16);
660657

661-
mmf_ids_data ids_case;
662-
auto prepare_ids = [&](int cols_case) -> const mmf_ids_data * {
663-
if (!ids_data || !ids_data->ids_src_compact) {
664-
return nullptr;
665-
}
666-
667-
if (ids_data->cols_per_tile != 0 && ids_data->cols_per_tile != cols_case) {
668-
return nullptr;
669-
}
670-
671-
ids_case = *ids_data;
672-
ids_case.cols_per_tile = cols_case;
673-
return &ids_case;
674-
};
675-
676-
mmf_ids_data * ids_case_ptr = nullptr;
677-
678658
switch (ncols_case) {
679659
case 1: {
680660
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
681661
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
682-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
662+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
683663
} break;
684664
case 2: {
685665
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
686666
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
687-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
667+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
688668
} break;
689669
case 3: {
690670
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
691671
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
692-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
672+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
693673
} break;
694674
case 4: {
695675
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
696676
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
697-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
677+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
698678
} break;
699679
case 5: {
700680
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
701681
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
702-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
682+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
703683
} break;
704684
case 6: {
705685
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
706686
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
707-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
687+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
708688
} break;
709689
case 7: {
710690
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
711691
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
712-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
692+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
713693
} break;
714694
case 8: {
715695
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
716696
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
717-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
697+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
718698
} break;
719699
case 9: {
720700
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
721701
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
722-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
702+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
723703
} break;
724704
case 10: {
725705
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
726706
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
727-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
707+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
728708
} break;
729709
case 11: {
730710
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
731711
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
732-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
712+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
733713
} break;
734714
case 12: {
735715
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
736716
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
737-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
717+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
738718
} break;
739719
case 13: {
740720
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
741721
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
742-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
722+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
743723
} break;
744724
case 14: {
745725
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
746726
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
747-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
727+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
748728
} break;
749729
case 15: {
750730
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
751731
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
752-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
732+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
753733
} break;
754734
case 16: {
755-
const mmf_ids_data * ids_case_ptr = prepare_ids(16);
756735
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
757736
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
758-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr);
737+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
759738
} break;
760739
default: {
761740
GGML_ABORT("fatal error");

ggml/src/ggml-cuda/mmid.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#include "common.cuh"
1+
#pragma once
22

33
void ggml_cuda_launch_mm_ids_helper(
44
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
5-
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
5+
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);

0 commit comments

Comments
 (0)