Commit a394039

ggml-cpu : add chunking support to mul_mat_id (#11666)
* ggml-cpu : add chunking support to mul_mat_id
* allocate chunk counter in wdata
  parallelize src1 quantization by column to allow parallelization even when there is only one row
* disable for arm
* cleanup
* better way to disable for arm
* fix uninitialized counter when using 1 thread only
* revert test-backend-ops changes
1 parent be3bbd6 commit a394039
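The headline change: instead of statically splitting each expert's matrix multiplication across threads, the work is cut into chunks that threads claim from a shared atomic counter, so fast threads pick up the slack of slow ones. A minimal sketch of that claiming pattern (names are illustrative; the counter is pre-initialized to nth, as in the commit, so the first nth chunks map one-to-one to thread ids):

    // Sketch of the chunk-claiming loop introduced by this commit.
    #include <stdatomic.h>

    void worker(atomic_int * counter, int ith, int nth, int nchunks) {
        int chunk = ith; // each thread starts on the chunk matching its id
        while (chunk < nchunks) {
            // process_chunk(chunk); // compute one tile of the output here
            if (nth >= nchunks) {
                break; // at most one chunk per thread, nothing to steal
            }
            // claim the next unprocessed chunk; relaxed ordering suffices
            // because a thread barrier follows the whole loop
            chunk = atomic_fetch_add_explicit(counter, 1, memory_order_relaxed);
        }
    }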


ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 184 additions & 85 deletions
@@ -7,10 +7,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
-#include "ggml-quants.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
-#include "amx/amx.h"
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -1291,7 +1289,7 @@ struct ggml_threadpool {
     atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
     atomic_int GGML_CACHE_ALIGN n_barrier;
     atomic_int GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;         // Used for stopping the threadpool altogether
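GGML_CACHE_ALIGN gives the shared chunk counter its own cache line so that threads hammering it with atomic increments do not keep invalidating the line that holds the barrier counters (false sharing). A standalone sketch of the idiom, assuming 64-byte lines (ggml derives the real constant per platform):

    // Sketch: one hot atomic per cache line, assuming 64-byte lines.
    #include <stdatomic.h>

    #define CACHE_ALIGN _Alignas(64)

    struct pool_counters {
        CACHE_ALIGN atomic_int n_barrier;        // touched at barriers
        CACHE_ALIGN atomic_int n_barrier_passed; // touched at barriers
        CACHE_ALIGN atomic_int current_chunk;    // hammered during mul_mat
    };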
@@ -7490,13 +7488,15 @@ UseGgmlGemm1:;
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7506,6 +7506,20 @@ UseGgmlGemm1:;
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
     if (ith == 0) {
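The replacement loop splits each src1 row along ne10 into per-thread column blocks instead of handing whole rows to threads, so quantization parallelizes even when ne11 == 1. The start/end arithmetic tiles the blocks exactly, with shares differing by at most one block; a standalone check of that property:

    // Sketch: the balanced block partition used above, verified standalone.
    // Thread ith of nth gets blocks [start, end); the ranges tile [0, n).
    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t n = 37, nth = 4; // arbitrary example sizes
        int64_t covered = 0;
        for (int64_t ith = 0; ith < nth; ++ith) {
            const int64_t start = (ith * n) / nth;
            const int64_t end   = ((ith + 1) * n) / nth;
            covered += end - start;
            printf("thread %lld: blocks [%lld, %lld)\n",
                   (long long) ith, (long long) start, (long long) end);
        }
        assert(covered == n); // no block dropped or duplicated
        return 0;
    }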
@@ -7593,7 +7607,6 @@ UseGgmlGemm2:;
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }
-
         ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
@@ -7606,6 +7619,84 @@ UseGgmlGemm2:;
 
 // ggml_compute_forward_mul_mat_id
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static void ggml_compute_forward_mul_mat_id_one_chunk(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    const struct ggml_tensor * ids,
+    const int64_t cur_a,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+    const char * src0_cur,
+    const struct mmid_row_mapping * matrix_rows,
+    const size_t row_size,
+    const bool src1_cont,
+    const void * wdata) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    float tmp[16];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
+                const int64_t _i12 = ir1; // logical row index for this expert
+
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
+
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                }
+
+                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
+            }
+        }
+    }
+}
+
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+
+    void * ptr = *p;
+    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
 static void ggml_compute_forward_mul_mat_id(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
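incr_ptr_aligned is a bump allocator over params->wdata: round the cursor up to the requested alignment, hand out size bytes, advance. A standalone sketch of the same idiom (PAD is a local stand-in for GGML_PAD and assumes power-of-two alignments):

    // Sketch: the bump-pointer idiom behind incr_ptr_aligned.
    #include <stddef.h>
    #include <stdint.h>

    #define PAD(x, n) (((x) + (n) - 1) & ~((uintptr_t)(n) - 1))

    static void * bump_aligned(void ** p, size_t size, size_t align) {
        uintptr_t ptr = PAD((uintptr_t) *p, align); // align the cursor first
        *p = (void *) (ptr + size);                 // advance past this block
        return (void *) ptr;
    }

    int main(void) {
        char buf[1024];
        void * cur = buf;
        // carve aligned sub-buffers out of one allocation, as the kernel does
        void * counts = bump_aligned(&cur, 8 * sizeof(int64_t), sizeof(int64_t));
        void * lines  = bump_aligned(&cur, 4 * 64, 64); // cache-line aligned
        (void) counts; (void) lines;
        return 0;
    }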
@@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
     enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
 
@@ -7641,41 +7731,60 @@ static void ggml_compute_forward_mul_mat_id(
     const int n_ids = ids->ne[0]; // n_expert_used
     const int n_as  = ne02;       // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    void * wdata_cur = params->wdata;
 
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
+    if (src1->type != vec_dot_type) {
+        incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    }
+
+    int64_t * matrix_row_counts = // [n_as]
+        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
+
+    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
+        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
+        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
+
+    GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
 
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
     if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
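For context, the rows each expert must process are gathered once up front: every (token, expert-slot) pair in ids is appended to the chosen expert's list, so the per-expert matmul can walk a dense range and MMID_MATRIX_ROW now indexes a full [n_as][ids->ne[0]*ids->ne[1]] table. A self-contained sketch of that bookkeeping, with made-up sizes and routing:

    // Sketch: grouping (expert-slot, token) pairs by expert, as
    // matrix_rows / matrix_row_counts do. Sizes and ids are illustrative.
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    struct row_mapping { int32_t i1; int32_t i2; };

    int main(void) {
        enum { N_AS = 4, N_IDS = 2, N_TOKENS = 3 }; // experts, slots, tokens
        // ids[t][e] = expert chosen for token t in slot e (made-up values)
        const int32_t ids[N_TOKENS][N_IDS] = {{0, 2}, {2, 3}, {0, 1}};

        int64_t counts[N_AS];
        struct row_mapping rows[N_AS][N_IDS * N_TOKENS];
        memset(counts, 0, sizeof(counts));

        for (int32_t t = 0; t < N_TOKENS; ++t) {
            for (int32_t e = 0; e < N_IDS; ++e) {
                const int32_t a = ids[t][e];
                rows[a][counts[a]++] = (struct row_mapping){ e, t };
            }
        }
        for (int a = 0; a < N_AS; ++a) { // each expert now has a dense row list
            printf("expert %d: %lld rows\n", a, (long long) counts[a]);
        }
        return 0;
    }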
@@ -7693,94 +7802,79 @@ static void ggml_compute_forward_mul_mat_id(
         }
     }
 
+    // reset current_chunk
+    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
+        *current_chunk_ctr = nth;
+    }
+
     ggml_barrier(params->threadpool);
 
-    // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
         const int64_t cne1 = matrix_row_counts[cur_a];
 
         if (cne1 == 0) {
            continue;
         }
 
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
+        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
-
-        // distribute the thread work across the inner or outer loop based on which one is larger
-
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
-
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
-
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+        const int64_t nr0 = ne01;
+        const int64_t nr1 = cne1;
 
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
+        int chunk_size = 16;
+        if (nr0 == 1 || nr1 == 1) {
+            chunk_size = 64;
+        }
 
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
+#if defined(__aarch64__)
+        // disable for ARM
+        const bool disable_chunking = true;
+#else
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+#endif // defined(__aarch64__)
 
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
+        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
+        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
+        }
 
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id = row_mapping.i1; // selected expert index
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-                    const int64_t i11 = id % ne11;
-                    const int64_t i12 = row_mapping.i2; // row index in src1
+        int current_chunk = ith;
 
-                    const int64_t i1 = id;  // selected expert index
-                    const int64_t i2 = i12; // row
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
 
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    //       the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12));
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end   = MIN(ir0_start + dr0, nr0);
 
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end   = MIN(ir1_start + dr1, nr1);
 
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
+            ggml_compute_forward_mul_mat_id_one_chunk(
+                dst, src0, src1, ids, cur_a,
+                ir0_start, ir0_end, ir1_start, ir1_end,
+                src0_cur, matrix_rows, row_size, src1_cont, wdata
+            );
 
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
+            if (nth >= nchunk0 * nchunk1) {
+                break;
             }
+
+            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
         }
     }
-
-#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
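The chunk heuristic matches the plain mul_mat path: 16 rows per chunk normally, 64 when either dimension degenerates to 1, and a fallback to the old one-slice-per-thread split when there are fewer than 4 chunks per thread or when chunking is disabled (ARM, NUMA). A worked example under assumed sizes:

    // Worked example of the chunk-count heuristic, with assumed sizes.
    #include <stdio.h>

    int main(void) {
        const long long nr0 = 4096, nr1 = 1; // src0 rows, rows routed to this expert
        const int nth = 8;                   // threads

        long long chunk_size = (nr0 == 1 || nr1 == 1) ? 64 : 16; // here: 64

        long long nchunk0 = (nr0 + chunk_size - 1) / chunk_size; // 64
        long long nchunk1 = (nr1 + chunk_size - 1) / chunk_size; // 1

        if (nchunk0 * nchunk1 < nth * 4) { // 64 >= 32: chunking stays on
            nchunk0 = nr0 > nr1 ? nth : 1;
            nchunk1 = nr0 > nr1 ? 1 : nth;
        }
        printf("%lld x %lld chunks for %d threads\n", nchunk0, nchunk1, nth);
        return 0;
    }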
@@ -13713,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan(
                    cur = 0;
                    const struct ggml_tensor * src0 = node->src[0];
                    const struct ggml_tensor * src1 = node->src[1];
+                   const struct ggml_tensor * ids = node->src[2];
                    const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+                   const int n_as = src0->ne[2];
+                   // src1
                    if (src1->type != vec_dot_type) {
-                       cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+                       cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
                    }
-                   const int n_as = src0->ne[2];
-                   cur += GGML_PAD(cur, sizeof(int64_t)); // align
-                   cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                   cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                   // matrix_row_counts
+                   cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+                   // matrix_rows
+                   cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+                   // atomic_current_chunk
+                   cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
               } break;
           case GGML_OP_OUT_PROD:
               {
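The planner has to reserve everything incr_ptr_aligned will later carve out of wdata; each extra sizeof(int64_t) or CACHE_LINE_SIZE term is worst-case slack for the alignment pad in front of that region. A sketch of the same accounting as a helper (names and parameters are illustrative):

    // Sketch: worst-case workspace accounting for mul_mat_id, mirroring
    // the planner above.
    #include <stddef.h>
    #include <stdint.h>

    struct mmid_row_mapping { int32_t i1; int32_t i2; };

    size_t mul_mat_id_wsize(size_t quantized_src1_bytes, size_t n_as,
                            size_t n_ids, size_t n_tokens, size_t cache_line) {
        size_t cur = 0;
        // quantized copy of src1, plus slack to align the next region
        cur += quantized_src1_bytes + sizeof(int64_t);
        // matrix_row_counts, plus alignment slack
        cur += n_as * sizeof(int64_t) + sizeof(int64_t);
        // matrix_rows, plus alignment slack
        cur += n_as * n_ids * n_tokens * sizeof(struct mmid_row_mapping)
             + sizeof(int64_t);
        // one cache line per expert for chunk counters, plus one to align
        cur += cache_line * n_as + cache_line;
        return cur;
    }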
