@@ -13,7 +13,6 @@ struct mmf_ids_data {
1313 const int32_t * expert_bounds_dev = nullptr ;
1414 int n_experts = 0 ;
1515 int sis1 = 0 ;
16- int cols_per_tile = 0 ;
1716};
1817
1918void ggml_cuda_mul_mat_f (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
@@ -502,8 +501,6 @@ static inline void mul_mat_f_switch_ids(
502501 if (max_tiles == 0 ) {
503502 return ;
504503 }
505- GGML_ASSERT (ids_data->cols_per_tile == 0 || ids_data->cols_per_tile == cols_per_block);
506-
507504 dim3 block_nums_ids (block_nums.x , ids_data->n_experts , max_tiles);
508505
509506 const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values ((uint32_t ) ids_data->sis1 ) : make_uint3 (0 , 0 , 1 );
@@ -658,104 +655,86 @@ static void mul_mat_f_switch_cols_per_block(
658655
659656 GGML_ASSERT (ids || ncols_dst <= 16 );
660657
661- mmf_ids_data ids_case;
662- auto prepare_ids = [&](int cols_case) -> const mmf_ids_data * {
663- if (!ids_data || !ids_data->ids_src_compact ) {
664- return nullptr ;
665- }
666-
667- if (ids_data->cols_per_tile != 0 && ids_data->cols_per_tile != cols_case) {
668- return nullptr ;
669- }
670-
671- ids_case = *ids_data;
672- ids_case.cols_per_tile = cols_case;
673- return &ids_case;
674- };
675-
676- mmf_ids_data * ids_case_ptr = nullptr ;
677-
678658 switch (ncols_case) {
679659 case 1 : {
680660 mul_mat_f_cuda<T, 1 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
681661 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
682- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
662+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
683663 } break ;
684664 case 2 : {
685665 mul_mat_f_cuda<T, 2 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
686666 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
687- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
667+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
688668 } break ;
689669 case 3 : {
690670 mul_mat_f_cuda<T, 3 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
691671 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
692- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
672+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
693673 } break ;
694674 case 4 : {
695675 mul_mat_f_cuda<T, 4 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
696676 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
697- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
677+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
698678 } break ;
699679 case 5 : {
700680 mul_mat_f_cuda<T, 5 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
701681 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
702- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
682+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
703683 } break ;
704684 case 6 : {
705685 mul_mat_f_cuda<T, 6 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
706686 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
707- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
687+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
708688 } break ;
709689 case 7 : {
710690 mul_mat_f_cuda<T, 7 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
711691 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
712- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
692+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
713693 } break ;
714694 case 8 : {
715695 mul_mat_f_cuda<T, 8 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
716696 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
717- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
697+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
718698 } break ;
719699 case 9 : {
720700 mul_mat_f_cuda<T, 9 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
721701 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
722- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
702+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
723703 } break ;
724704 case 10 : {
725705 mul_mat_f_cuda<T, 10 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
726706 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
727- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
707+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
728708 } break ;
729709 case 11 : {
730710 mul_mat_f_cuda<T, 11 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
731711 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
732- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
712+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
733713 } break ;
734714 case 12 : {
735715 mul_mat_f_cuda<T, 12 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
736716 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
737- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
717+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
738718 } break ;
739719 case 13 : {
740720 mul_mat_f_cuda<T, 13 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
741721 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
742- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
722+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
743723 } break ;
744724 case 14 : {
745725 mul_mat_f_cuda<T, 14 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
746726 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
747- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
727+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
748728 } break ;
749729 case 15 : {
750730 mul_mat_f_cuda<T, 15 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
751731 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
752- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
732+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
753733 } break ;
754734 case 16 : {
755- const mmf_ids_data * ids_case_ptr = prepare_ids (16 );
756735 mul_mat_f_cuda<T, 16 >(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
757736 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
758- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr );
737+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data );
759738 } break ;
760739 default : {
761740 GGML_ABORT (" fatal error" );
0 commit comments