 #include "ggml-cpu-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
-#include "ggml-quants.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
-#include "amx/amx.h"
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -1291,7 +1289,7 @@ struct ggml_threadpool {
     atomic_int n_graph;       // incremented when there is work to be done (i.e. each graph)
     atomic_int GGML_CACHE_ALIGN n_barrier;
     atomic_int GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;         // Used for stopping the threadpool altogether
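The `GGML_CACHE_ALIGN` added to `current_chunk` matters because every worker thread hammers this counter; if it shared a cache line with its neighbors, each increment would invalidate that line on all other cores. A minimal standalone sketch of the effect, assuming C11 `_Alignas` and a 64-byte cache line (both assumptions of this sketch, not part of the patch):

```c
#include <stdatomic.h>
#include <stdio.h>

#define CACHE_LINE 64 // assumed line size; ggml uses its own GGML_CACHE_ALIGN

struct counters_packed {
    atomic_int a; // a and b share one cache line: writes to a evict b's line
    atomic_int b; // on other cores (false sharing)
};

struct counters_padded {
    _Alignas(CACHE_LINE) atomic_int a; // each counter owns a full line
    _Alignas(CACHE_LINE) atomic_int b;
};

int main(void) {
    printf("packed: %zu bytes, padded: %zu bytes\n",
           sizeof(struct counters_packed), sizeof(struct counters_padded));
    return 0;
}
```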
@@ -7490,13 +7488,15 @@ UseGgmlGemm1:;
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7506,6 +7506,20 @@ UseGgmlGemm1:;
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t  bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = ( ith      * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
     if (ith == 0) {
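The `#else` branch changes the split axis: instead of giving each thread whole `i11` rows, every thread converts a contiguous run of quant blocks inside every row, so even a single long row keeps all threads busy. A standalone sketch of the `(ith * nblocks) / nth` partition arithmetic (the dimensions below are invented for illustration):

```c
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne10 = 4096; // row length in elements (example value)
    const int64_t bs   = 32;   // block size of the vec_dot_type (example value)
    const int     nth  = 6;    // deliberately not a divisor of ne10/bs

    const int64_t nblocks = ne10/bs;
    for (int ith = 0; ith < nth; ++ith) {
        // same bounds as ne10_block_start/ne10_block_end above: the ranges
        // tile [0, nblocks) exactly, with sizes differing by at most one
        const int64_t start = ( ith      * nblocks) / nth;
        const int64_t end   = ((ith + 1) * nblocks) / nth;
        printf("thread %d: blocks [%3" PRId64 ", %3" PRId64 ")\n", ith, start, end);
    }
    return 0;
}
```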
@@ -7593,7 +7607,6 @@ UseGgmlGemm2:;
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }
-
         ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0*nchunk1) {
@@ -7606,6 +7619,84 @@ UseGgmlGemm2:;
 
 // ggml_compute_forward_mul_mat_id
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static void ggml_compute_forward_mul_mat_id_one_chunk(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    const struct ggml_tensor * ids,
+    const int64_t cur_a,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+    const char * src0_cur,
+    const struct mmid_row_mapping * matrix_rows,
+    const size_t row_size,
+    const bool src1_cont,
+    const void * wdata) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    float tmp[16];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
+                const int64_t _i12 = ir1; // logical row index for this expert
+
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
+
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                }
+
+                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
+            }
+        }
+    }
+}
+
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+
+    void * ptr = *p;
+    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
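`incr_ptr_aligned()` is a plain bump allocator over the preallocated workspace: round the cursor up to `align`, hand out `size` bytes, advance. A self-contained sketch of the same pattern (the buffer size and sub-allocations here are invented for the example):

```c
#include <stdint.h>
#include <stdio.h>

// power-of-two round-up, same job as GGML_PAD in the patch
#define PAD(x, n) (((x) + (n) - 1) & ~((uintptr_t)(n) - 1))

static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
    void * ptr = (void *) PAD((uintptr_t) *p, align);
    *p = (void *) ((char *) ptr + size);
    return ptr;
}

int main(void) {
    char buf[1024]; // stand-in for params->wdata
    void * cur = buf;

    int64_t * counts          = incr_ptr_aligned(&cur, 8 * sizeof(int64_t), sizeof(int64_t));
    char   (* chunk_ctrs)[64] = incr_ptr_aligned(&cur, 8 * 64, 64); // one cache line per expert

    printf("used %td of %zu bytes\n", (char *) cur - buf, sizeof buf);
    (void) counts; (void) chunk_ctrs;
    return 0;
}
```

The `GGML_ASSERT(params->wsize >= ...)` added below is the safety net for this scheme: the planner must have reserved at least as much as the carving consumes.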
@@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t    const vec_dot      = type_traits_cpu[type].vec_dot;
     enum ggml_type    const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float   = type_traits_cpu[vec_dot_type].from_float;
 
@@ -7641,41 +7731,60 @@ static void ggml_compute_forward_mul_mat_id(
     const int n_ids = ids->ne[0]; // n_expert_used
     const int n_as  = ne02;       // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    void * wdata_cur = params->wdata;
 
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
+    if (src1->type != vec_dot_type) {
+        incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    }
+
+    int64_t * matrix_row_counts = // [n_as]
+        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
+
+    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
+        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
+        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
+
+    GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
 
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
            }
        }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t  bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = ( ith      * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
     if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
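One detail worth noting in the carving above: `atomic_current_chunk` is declared as a pointer to `char[CACHE_LINE_SIZE]`, so `atomic_current_chunk + cur_a` advances by a whole cache line per expert and each expert's chunk counter lands on its own line. A toy demonstration of the stride (array shape chosen arbitrarily):

```c
#include <stdio.h>

#define CACHE_LINE_SIZE 64 // assumption for this sketch

int main(void) {
    static char storage[8][CACHE_LINE_SIZE]; // 8 "experts"
    char (*p)[CACHE_LINE_SIZE] = storage;

    // pointer arithmetic steps in units of the array element, i.e. 64 bytes
    printf("stride = %td bytes\n", (char *)(p + 1) - (char *)p);
    return 0;
}
```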
@@ -7693,94 +7802,79 @@ static void ggml_compute_forward_mul_mat_id(
         }
     }
 
+    // reset current_chunk
+    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
+        *current_chunk_ctr = nth;
+    }
+
     ggml_barrier(params->threadpool);
 
-    // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
         const int64_t cne1 = matrix_row_counts[cur_a];
 
         if (cne1 == 0) {
             continue;
         }
 
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
-
-        // distribute the thread work across the inner or outer loop based on which one is larger
-
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
-
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
-
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+        const int64_t nr0 = ne01;
+        const int64_t nr1 = cne1;
 
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
+        int chunk_size = 16;
+        if (nr0 == 1 || nr1 == 1) {
+            chunk_size = 64;
+        }
 
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
+#if defined(__aarch64__)
+        // disable for ARM
+        const bool disable_chunking = true;
+#else
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+#endif // defined(__aarch64__)
 
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
+        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
+        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
+        }
 
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id = row_mapping.i1; // selected expert index
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-                    const int64_t i11 = id % ne11;
-                    const int64_t i12 = row_mapping.i2; // row index in src1
+        int current_chunk = ith;
 
-                    const int64_t i1 = id;  // selected expert index
-                    const int64_t i2 = i12; // row
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
 
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    //       the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                            ? (i11      + i12*ne11)*row_size
-                            : (i11*nb11 + i12*nb12));
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end   = MIN(ir0_start + dr0, nr0);
 
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end   = MIN(ir1_start + dr1, nr1);
 
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
+            ggml_compute_forward_mul_mat_id_one_chunk(
+                dst, src0, src1, ids, cur_a,
+                ir0_start, ir0_end, ir1_start, ir1_end,
+                src0_cur, matrix_rows, row_size, src1_cont, wdata
+            );
 
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
+            if (nth >= nchunk0 * nchunk1) {
+                break;
             }
+
+            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
         }
     }
-
-#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
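The rewritten loop replaces the old static row split with dynamic chunking: each per-expert counter starts at `nth`, thread `ith` seeds itself with chunk `ith`, and every further chunk is claimed with a relaxed `atomic_fetch_add`, so faster threads simply take more chunks. A toy model of that scheduling, assuming C11 `<threads.h>` is available (it is optional in the standard):

```c
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>
#include <threads.h>

enum { NTH = 4, NCHUNKS = 13 };

static atomic_int counter;       // shared chunk counter, as in the patch
static atomic_int done[NCHUNKS]; // how often each chunk was processed

static int worker(void * arg) {
    int chunk = (int)(intptr_t) arg; // first chunk is assigned statically (ith)
    while (chunk < NCHUNKS) {
        atomic_fetch_add(&done[chunk], 1); // "process" the chunk
        chunk = atomic_fetch_add_explicit(&counter, 1, memory_order_relaxed);
    }
    return 0;
}

int main(void) {
    atomic_store(&counter, NTH); // mirrors "*current_chunk_ctr = nth" above
    thrd_t t[NTH];
    for (int i = 0; i < NTH; ++i) thrd_create(&t[i], worker, (void *)(intptr_t) i);
    for (int i = 0; i < NTH; ++i) thrd_join(t[i], NULL);
    for (int i = 0; i < NCHUNKS; ++i) {
        printf("chunk %2d processed %d time(s)\n", i, atomic_load(&done[i])); // always 1
    }
    return 0;
}
```

The `if (nth >= nchunk0 * nchunk1) break;` in the patch skips the atomic entirely when there is at most one chunk per thread, which is also why chunking is bypassed on aarch64 and NUMA systems, where contending on the shared counter is comparatively expensive.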
@@ -13713,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan(
                     cur = 0;
                     const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
+                    const struct ggml_tensor * ids  = node->src[2];
                     const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+                    const int n_as = src0->ne[2];
+                    // src1
                     if (src1->type != vec_dot_type) {
-                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
                     }
-                    const int n_as = src0->ne[2];
-                    cur += GGML_PAD(cur, sizeof(int64_t));       // align
-                    cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    // matrix_row_counts
+                    cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+                    // matrix_rows
+                    cur += n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+                    // atomic_current_chunk
+                    cur += CACHE_LINE_SIZE * n_as + CACHE_LINE_SIZE;
                 } break;
             case GGML_OP_OUT_PROD:
                 {
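The planner now has to reserve exactly what the `incr_ptr_aligned` carving can consume, including one alignment unit of slack per sub-allocation (the worst-case round-up of the cursor). A compact restatement of the sizing rule; the concrete numbers in `main` are illustrative only:

```c
#include <stdint.h>
#include <stdio.h>

#define CACHE_LINE_SIZE 64 // assumed value for this sketch

struct mmid_row_mapping { int32_t i1, i2; };

// mirrors the GGML_OP_MUL_MAT_ID sizing above: payload plus one alignment
// unit of slack per aligned sub-allocation
static size_t mul_mat_id_scratch(int n_as, int64_t n_ids_total) {
    size_t cur = 0;
    cur += n_as * sizeof(int64_t) + sizeof(int64_t);                               // matrix_row_counts
    cur += n_as * n_ids_total * sizeof(struct mmid_row_mapping) + sizeof(int64_t); // matrix_rows
    cur += (size_t) CACHE_LINE_SIZE * n_as + CACHE_LINE_SIZE;                      // atomic_current_chunk
    return cur;
}

int main(void) {
    // e.g. 8 experts, ids of shape 2 x 32 (n_expert_used x n_tokens)
    printf("scratch: %zu bytes\n", mul_mat_id_scratch(8, 2 * 32));
    return 0;
}
```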