@@ -2160,6 +2160,7 @@ struct ggml_compute_params {
 
     // work buffer for all threads
     size_t wsize;
+    size_t qsize;
    void * wdata;
 
    struct ggml_threadpool * threadpool;
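For orientation (this note and the sketch below are editorial, not part of the commit): `qsize` carves a dedicated quantization region out of the tail of the existing work buffer, so every worker thread can locate it from `wsize` and `qsize` alone.

```c
// Editorial sketch, not the commit's code: how a worker derives the
// quantization region from the new field. Assumes ggml_graph_plan() folded
// q_size into work_size, so the region is the last qsize bytes of wdata.
#include <stddef.h>

struct params_sketch {
    size_t wsize;  // total work buffer size, q region included
    size_t qsize;  // size of the quantization region at the tail
    void * wdata;  // base of the shared work buffer
};

// The first GGML_MAX_NAME bytes of the region hold a name tag (see the
// mul_mat hunks below); the quantized activations follow it.
static char * q_region(const struct params_sketch * p) {
    return (char *)p->wdata + p->wsize - p->qsize;
}
```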
@@ -12781,13 +12782,18 @@ UseGgmlGemm1:;
 #endif
 
    if (src1->type != vec_dot_type) {
-        char * wdata = params->wdata;
+        char * wdata = (char *)params->wdata + params->wsize - params->qsize;
+
+        if (strncmp(src1->name, wdata, GGML_MAX_NAME) == 0) {
+            goto AlreadyQuantized;
+        }
+        wdata += GGML_MAX_NAME;
 
        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;
 
-        assert(params->wsize >= ne13*nbw3);
+        assert(params->qsize >= ne13*nbw3);
        GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
@@ -12808,14 +12814,21 @@ UseGgmlGemm1:;
                }
            }
        }
-    }
 
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+        ggml_barrier(params->threadpool);
+
+        if (ith == 0) {
+            wdata -= GGML_MAX_NAME;
+            memcpy(wdata, src1->name, GGML_MAX_NAME);
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+        }
+
+    AlreadyQuantized:;
    }
 
-    ggml_barrier(params->threadpool);
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data
+        : (const void *)((const char *)params->wdata + params->wsize - params->qsize + GGML_MAX_NAME);
 
 #if GGML_USE_LLAMAFILE
    if (src1->type != vec_dot_type) {
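The hunk above implements a small one-entry cache: the first `GGML_MAX_NAME` bytes of the quantization region hold the name of the tensor whose quantized rows currently occupy it, so a later matmul that consumes the same activations (e.g. one hidden state feeding several weight matrices) can skip re-quantization. A minimal single-threaded sketch of the pattern, with hypothetical helper names:

```c
// Editorial sketch of the name-tag cache, assuming `name` points at a
// fixed-size MAX_NAME buffer (ggml tensor names are char[GGML_MAX_NAME]).
// The real code quantizes cooperatively across threads and publishes the
// tag only on thread 0, after a barrier.
#include <string.h>

#define MAX_NAME 128  // stand-in for GGML_MAX_NAME

static void * get_or_quantize(char * qdata, const char * name,
                              void (*quantize_to)(void * dst)) {
    if (strncmp(name, qdata, MAX_NAME) != 0) {  // tag miss: (re)quantize
        quantize_to(qdata + MAX_NAME);
        memcpy(qdata, name, MAX_NAME);          // publish the tag last
    }
    return qdata + MAX_NAME;                    // quantized payload
}
```

Publishing the tag only after the payload is written means a stale or garbage tag can at worst cause a redundant re-quantization, never a read of invalid data.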
@@ -12966,9 +12979,10 @@ static void ggml_compute_forward_mul_mat_id(
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as  = ne02;       // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    char * qdata = (char *)params->wdata + params->wsize - params->qsize;
+
+    char * wdata_src1_end = (src1->type == vec_dot_type) ? qdata :
+            qdata + GGML_PAD(GGML_MAX_NAME + ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
 
    struct mmid_row_mapping {
        int32_t i1;
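With the `GGML_MAX_NAME` tag prepended, the q region for MUL_MAT_ID is laid out as: name tag, quantized `src1`, padding to `int64_t` alignment, then the `matrix_row_counts` and `matrix_rows` bookkeeping. An editorial sketch of the offset arithmetic, mirroring `wdata_src1_end` above:

```c
// Editorial sketch of the MUL_MAT_ID q-region layout implied by the diff:
//   [ name tag | quantized src1 | pad to 8 | row counts | row mappings ]
#include <stdint.h>
#include <stddef.h>

#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))  // same idea as GGML_PAD

// Byte offset of the row-count/row-mapping bookkeeping inside the region,
// matching qdata + GGML_PAD(GGML_MAX_NAME + quantized_bytes, sizeof(int64_t)).
static size_t bookkeeping_offset(size_t max_name, size_t quantized_bytes) {
    return PAD(max_name + quantized_bytes, sizeof(int64_t));
}
```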
@@ -12978,14 +12992,19 @@ static void ggml_compute_forward_mul_mat_id(
    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
+    bool store_name = false;
    if (src1->type != vec_dot_type) {
-        char * wdata = params->wdata;
+        if (strncmp(src1->name, qdata, GGML_MAX_NAME) == 0) {
+            goto QuantizationAlreadyDone;
+        }
+        store_name = true;
+        char * wdata = qdata + GGML_MAX_NAME;
 
        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;
 
-        assert(params->wsize >= ne13*nbw3);
+        assert(params->qsize >= ne13*nbw3);
        GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
@@ -13001,7 +13020,12 @@ static void ggml_compute_forward_mul_mat_id(
 
 #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
 
+QuantizationAlreadyDone:;
    if (ith == 0) {
+        if (store_name) {
+            memcpy(qdata, src1->name, GGML_MAX_NAME);
+        }
+
        // initialize matrix_row_counts
        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
@@ -13030,7 +13054,7 @@ static void ggml_compute_forward_mul_mat_id(
 
        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
 
-        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : qdata + GGML_MAX_NAME;
        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
        const int64_t nr0 = ne01; // src0 rows
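Downstream, the kernel no longer cares how the rows got there; it only picks a base pointer: the tensor itself when `src1` is already in `vec_dot_type`, otherwise the cached copy just past the name tag. An editorial sketch of the row addressing:

```c
// Editorial sketch of src1 row lookup after this change; `max_name` and
// `row_size` stand in for GGML_MAX_NAME and ggml_row_size(vec_dot_type, ne10).
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

static const void * src1_row(bool already_vec_dot, const void * src1_data,
                             const char * qdata, size_t max_name,
                             size_t row_size, int64_t i) {
    const char * base = already_vec_dot ? (const char *)src1_data
                                        : qdata + max_name;  // skip the tag
    return base + (size_t)i * row_size;
}
```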
@@ -19983,6 +20007,7 @@ struct ggml_cplan ggml_graph_plan(
    }
 
    size_t work_size = 0;
+    size_t q_size = 0;
 
    struct ggml_cplan cplan;
    memset(&cplan, 0, sizeof(struct ggml_cplan));
@@ -19998,6 +20023,7 @@ struct ggml_cplan ggml_graph_plan(
        max_tasks = MAX(max_tasks, n_tasks);
 
        size_t cur = 0;
+        size_t cur_q = 0;
 
        switch (node->op) {
            case GGML_OP_CPY:
@@ -20037,7 +20063,7 @@ struct ggml_cplan ggml_graph_plan(
                    } else
 #endif
                    if (node->src[1]->type != vec_dot_type) {
-                        cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
+                        cur_q = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
                    }
                } break;
            case GGML_OP_MUL_MAT_ID:
@@ -20047,12 +20073,12 @@ struct ggml_cplan ggml_graph_plan(
                    const struct ggml_tensor * src1 = node->src[1];
                    const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                    if (src1->type != vec_dot_type) {
-                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+                        cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                    }
                    const int n_as = src0->ne[2];
-                    cur += GGML_PAD(cur, sizeof(int64_t)); // align
-                    cur += n_as * sizeof(int64_t); // matrix_row_counts
-                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    cur_q += GGML_PAD(cur_q, sizeof(int64_t)); // align
+                    cur_q += n_as * sizeof(int64_t); // matrix_row_counts
+                    cur_q += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
                } break;
            case GGML_OP_OUT_PROD:
                {
@@ -20141,15 +20167,21 @@ struct ggml_cplan ggml_graph_plan(
        }
 
        work_size = MAX(work_size, cur);
+        q_size = MAX(q_size, cur_q);
    }
 
    if (work_size > 0) {
        work_size += CACHE_LINE_SIZE*(n_threads);
    }
+    if (q_size > 0) {
+        q_size += GGML_MAX_NAME;
+    }
+    work_size += q_size;
 
    cplan.threadpool = threadpool;
    cplan.n_threads = MIN(max_tasks, n_threads);
    cplan.work_size = work_size;
+    cplan.q_size = q_size;
    cplan.work_data = NULL;
 
    return cplan;
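The sizing logic now tracks the two scratch needs separately and then appends the quantization region (plus room for the tag) to the generic work buffer, which is what makes `wdata + wsize - qsize` a valid way to find it at compute time. A condensed, editorial restatement:

```c
// Editorial restatement of the plan arithmetic above; CACHE_LINE and
// MAX_NAME stand in for CACHE_LINE_SIZE and GGML_MAX_NAME.
#include <stddef.h>

enum { CACHE_LINE = 64, MAX_NAME = 128 };

static size_t plan_buffer_size(size_t work_size, size_t q_size, int n_threads) {
    if (work_size > 0) work_size += CACHE_LINE * (size_t)n_threads;
    if (q_size    > 0) q_size    += MAX_NAME;  // room for the name tag
    return work_size + q_size;  // q region ends up as the last q_size bytes
}
```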
@@ -20168,6 +20200,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
        /*.ith =*/ state->ith,
        /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
        /*.wsize =*/ cplan->work_size,
+        /*.qsize =*/ cplan->q_size,
        /*.wdata =*/ cplan->work_data,
        /*.threadpool=*/ tp,
    };