@@ -737,46 +737,45 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
737737
738738 auto etypeA = ggml_type (typeA);
739739 if (auto dequant_type = MulMat::is_dequant_better (etypeA, Ny); dequant_type != etypeA) {
740- if (!MulMat::prepare (dequant_type, typeB, ne00, mm, Ny)) {
741- return false ;
742- }
740+ if (MulMat::prepare (dequant_type, typeB, ne00, mm, Ny)) {
743741
744- constexpr int k_x_step = 64 ;
742+ constexpr int k_x_step = 64 ;
745743
746- auto num_rows = MulMat::num_rows (ggml_type (dequant_type));
747- GGML_ASSERT (Nx%num_rows == 0 );
748- auto nrc_x = (Nx/num_rows + nth - 1 )/nth;
749- auto first_x = ith*nrc_x;
750- if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
751- first_x *= num_rows;
752- nrc_x *= num_rows;
744+ auto num_rows = MulMat::num_rows (ggml_type (dequant_type));
745+ GGML_ASSERT (Nx%num_rows == 0 );
746+ auto nrc_x = (Nx/num_rows + nth - 1 )/nth;
747+ auto first_x = ith*nrc_x;
748+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
749+ first_x *= num_rows;
750+ nrc_x *= num_rows;
753751
754- size_t row_size_qx = ggml_row_size (dequant_type, ne00);
755- size_t row_size_qy = strideB;
752+ size_t row_size_qx = ggml_row_size (dequant_type, ne00);
753+ size_t row_size_qy = strideB;
756754
757- DataInfo info{C + first_x, (const char *)B, nb1/sizeof (float ), row_size_qy, 0 , ne11, row_mapping, nb2/sizeof (float )};
755+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof (float ), row_size_qy, 0 , ne11, row_mapping, nb2/sizeof (float )};
758756
759- auto & f = thread_local_work_buffer ();
757+ auto & f = thread_local_work_buffer ();
760758
761- for (int ix = 0 ; ix < nrc_x; ix += k_x_step) {
762- auto this_info = info;
763- this_info.s += ix;
764- int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
765- if (f.size () < 2 *row_size_qx*this_nrc_x) f.resize (2 *row_size_qx*this_nrc_x);
766- auto Xu = f.data ();
767- auto Xg = f.data () + row_size_qx*this_nrc_x;
768- if (!iqk_convert_repack (typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
769- GGML_ABORT (" Fatal error" );
770- }
771- if (!iqk_convert_repack (typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
772- GGML_ABORT (" Fatal error" );
759+ for (int ix = 0 ; ix < nrc_x; ix += k_x_step) {
760+ auto this_info = info;
761+ this_info.s += ix;
762+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
763+ if (f.size () < 2 *row_size_qx*this_nrc_x) f.resize (2 *row_size_qx*this_nrc_x);
764+ auto Xu = f.data ();
765+ auto Xg = f.data () + row_size_qx*this_nrc_x;
766+ if (!iqk_convert_repack (typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
767+ GGML_ABORT (" Fatal error" );
768+ }
769+ if (!iqk_convert_repack (typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
770+ GGML_ABORT (" Fatal error" );
771+ }
772+ auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr ;
773+ auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr ;
774+ mm.mul_mat_up_gate_NxM (ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
773775 }
774- auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr ;
775- auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr ;
776- mm.mul_mat_up_gate_NxM (ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
777- }
778776
779- return true ;
777+ return true ;
778+ }
780779
781780 }
782781
0 commit comments