Commit dbfd151

Authored by Iwan Kawrakow (ikawrakow)
Grouped expert routing (CPU only) (ikawrakow#836)
* Better argsort (CPU)
* Attempt at grouped topk
* This seems to do the trick for grouped experts routing
* Cleanup
* Trying to merge, something is not right
* Working merged grouped top_k (CPU)
* Add command line option to enable grouped expert routing
* Add grouped expert routing option to llama-bench

Co-authored-by: Iwan Kawrakow <[email protected]>
Parent: ecf8f93 · Commit: dbfd151

File tree: 11 files changed, +220 −43 lines

common/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.fused_moe_up_gate = true;
         return true;
     }
+    if (arg == "-ger" || arg == "--grouped-expert-routing") {
+        params.grouped_expert_routing = true;
+        return true;
+    }
     if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
         params.fused_up_gate = false;
         return true;

@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,", "experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"

@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.mla_attn          = params.mla_attn;
     cparams.attn_max_batch    = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+    cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate     = params.fused_up_gate;
     cparams.min_experts       = params.min_experts;
     cparams.thresh_experts    = params.thresh_experts;

@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
     fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+    fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "ser: %d,%g # default: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

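With this change, grouped expert routing can be switched on at launch like the other boolean options. A hedged usage sketch — the binary name and model path are illustrative, not taken from this commit:

    ./llama-cli -m /models/ling-mini-2.gguf -fmoe -ger -p "Hello"

Unlike -fmoe, the flag is presumably only meaningful for models whose architecture defines expert groups (per the comment in common.h, the BailingMoeV2 family).
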
common/common.h

Lines changed: 1 addition & 0 deletions
@@ -235,6 +235,7 @@ struct gpt_params {
     int   attn_max_batch    = 0;     // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool  fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
     bool  fused_up_gate     = true;  // fused up*unary(gate) op
+    bool  grouped_expert_routing = false; // whether to use grouped expert routing (BailingMoeV2 arch)
     int   min_experts       = -1;
     float thresh_experts    = 0;

examples/llama-bench/llama-bench.cpp

Lines changed: 35 additions & 4 deletions
@@ -261,6 +261,7 @@ struct cmd_params {
     bool warmup;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false; // ger = Grouped Expert Routing
     bool no_fug = false;
     bool use_thp = false;
     output_formats output_format;

@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
     /* verbose              */ false,
     /* warmup               */ true,
     /* repack               */ false,
-    /* use_thp              */ false,
     /* fmoe                 */ false,
+    /* ger                  */ false,
     /* no_fug               */ false,
+    /* use_thp              */ false,
     /* output_format        */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };

@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -thp, --transparent-huge-pages <0|1>    (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
     printf("  -ot, --override-tensor pattern          (default: none)\n");
     printf("  -fmoe, --fused-moe <0|1>                (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
+    printf("  -ger, --grouped-expert-routing <0|1>    (default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
     printf("  -no-fug, --no-fused-up-gate <0|1>       (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");

@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.fmoe = std::stoi(argv[i]);
+        } else if (arg == "-ger" || arg == "--grouped-expert-routing") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ger = std::stoi(argv[i]);
         } else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
             if (++i >= argc) {
                 invalid_param = true;

@@ -829,6 +838,7 @@ struct cmd_params_instance {
     bool embeddings;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false;
     bool no_fug = false;
     bool use_thp = false;
     const llama_model_tensor_buft_override* buft_overrides;

@@ -876,6 +886,7 @@ struct cmd_params_instance {
         cparams.mla_attn = mla_attn;
         cparams.attn_max_batch = attn_max_batch;
         cparams.fused_moe_up_gate = fmoe;
+        cparams.grouped_expert_routing = ger;
         cparams.fused_up_gate = !no_fug;
         cparams.min_experts = ser.first;
         cparams.thresh_experts = ser.second;

@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
                 /* .fmoe         = */ params.fmoe,
+                /* .ger          = */ params.ger,
                 /* .no_fug       = */ params.no_fug,
                 /* .use_thp      = */ params.use_thp,
                 /* .buft_overrides=*/ params.buft_overrides.data(),

@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
                 /* .fmoe         = */ params.fmoe,
+                /* .ger          = */ params.ger,
                 /* .no_fug       = */ params.no_fug,
                 /* .use_thp      = */ params.use_thp,
                 /* .buft_overrides=*/ params.buft_overrides.data(),

@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
                 /* .fmoe         = */ params.fmoe,
+                /* .ger          = */ params.ger,
                 /* .no_fug       = */ params.no_fug,
                 /* .use_thp      = */ params.use_thp,
                 /* .buft_overrides=*/ params.buft_overrides.data(),

@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
                 /* .fmoe         = */ params.fmoe,
+                /* .ger          = */ params.ger,
                 /* .no_fug       = */ params.no_fug,
                 /* .use_thp      = */ params.use_thp,
                 /* .buft_overrides=*/ params.buft_overrides.data(),

@@ -1086,6 +1101,7 @@ struct test {
     bool embeddings;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false;
     bool no_fug = false;
     bool use_thp = false;
     int n_prompt;

@@ -1120,6 +1136,8 @@ struct test {
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
         repack = inst.repack;
+        fmoe = inst.fmoe;
+        ger = inst.ger;
         no_fug = inst.no_fug;
         use_thp = inst.use_thp;
         n_prompt = inst.n_prompt;

@@ -1212,7 +1230,7 @@ struct test {
         "n_threads", "type_k", "type_v",
         "n_gpu_layers", "split_mode",
         "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
-        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp",
+        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
         "avg_ts", "stddev_ts", "test",

@@ -1234,7 +1252,7 @@ struct test {
         if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" || field == "f16_kv" || field == "no_kv_offload" ||
             field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
-            field == "fused_moe" || field == "fused_up_gate") {
+            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {

@@ -1277,7 +1295,8 @@ struct test {
             std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
             std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
-            std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp),
+            std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
+            std::to_string(no_fug), std::to_string(use_thp),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
             std::to_string(avg_ts()), std::to_string(stdev_ts()),

@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return 4;
         }
+        if (field == "grouped_er") {
+            return 3;
+        }
         if (field == "fused_up_gate") {
             return 6;
         }

@@ -1513,6 +1535,12 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return "fmoe";
         }
+        if (field == "grouped_er") {
+            return "ger";
+        }
+        if (field == "grouped_er") {
+            return "ger";
+        }
         if (field == "fused_up_gate") {
             return "no-fug";
         }

@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
         if (params.fmoe != cmd_params_defaults.fmoe) {
             fields.emplace_back("fused_moe");
         }
+        if (params.ger != cmd_params_defaults.ger) {
+            fields.emplace_back("grouped_er");
+        }
         if (params.no_fug != cmd_params_defaults.no_fug) {
             fields.emplace_back("fused_up_gate");
         }

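Because llama-bench accepts comma-separated value lists for its parameters (see the usage text above), the new switch can be swept within a single run. A hedged example invocation — the model path is illustrative:

    ./llama-bench -m /models/ling-mini-2.gguf -fmoe 1 -ger 0,1

This benchmarks the model with grouped expert routing disabled and then enabled; since params.ger then differs from the default, the markdown printer adds the ger column (3 characters wide, per the width hook above) to the results table.
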
ggml/include/ggml.h

Lines changed: 8 additions & 0 deletions
@@ -650,6 +650,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_ARGSORT_THRESH,
+        GGML_OP_GROUPED_TOPK,
         GGML_OP_LEAKY_RELU,
         GGML_OP_SOFTCAP,
         GGML_OP_SOFT_CAP_MAX,

@@ -2265,6 +2266,13 @@ extern "C" {
             int                   k,
             int                   min_entries,
             float                 thresh);
+    GGML_API struct ggml_tensor * ggml_grouped_topk(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   num_groups,
+            int                   num_top_groups,
+            int                   nk,
+            int                   topk_experts);

 #define GGML_KQ_MASK_PAD 16

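For orientation, a minimal sketch of constructing the new op. The context setup and sizes are invented for illustration, and the interpretation of nk (how many of a group's best scores rank the group) is an assumption inferred from BailingMoeV2-style routing, not spelled out in the header:

    // A minimal sketch: 256 experts in 8 groups, one batch of 8 tokens.
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // [n_experts, n_tokens] router scores
    struct ggml_tensor * probs = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 8);

    // keep the best 4 of 8 groups, then select the top 8 experts overall
    struct ggml_tensor * ids = ggml_grouped_topk(ctx, probs,
            /*num_groups     =*/ 8,
            /*num_top_groups =*/ 4,
            /*nk             =*/ 2,
            /*topk_experts   =*/ 8);
    // ids is an I32 tensor with ne[0] == topk_experts; it holds the selected
    // expert indices per token once the graph containing it is computed.

The asserts in the implementation below spell out the contract: num_top_groups <= num_groups, a->ne[0] divisible by num_groups, a->ne[0] >= topk_experts, and at least nk experts per group.
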
ggml/src/ggml.c

Lines changed: 64 additions & 2 deletions
@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "ARGSORT_THRESH",
+    "GROUPED_TOPK",
     "LEAKY_RELU",
     "SOFTCAP",
     "SOFT_CAP_MAX",

@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };

-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "argsort_thresh(x)",
+    "grouped_topk(x)",
     "leaky_relu(x)",
     "k2*tanh(k1*x)",
     "soft_max(k2*tanh(k1*x))",

@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x),"
 };

-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh(
     return result;
 }

+struct ggml_tensor * ggml_grouped_topk(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   num_groups,
+        int                   num_top_groups,
+        int                   nk,
+        int                   topk_experts) {
+
+    GGML_ASSERT(num_top_groups <= num_groups);
+    GGML_ASSERT(a->ne[0] % num_groups == 0);
+    GGML_ASSERT(a->ne[0] >= topk_experts);
+    int64_t n_per_group = a->ne[0] / num_groups;
+    GGML_ASSERT(n_per_group >= nk);
+
+    bool is_node = false;
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i];
+    ne[0] = topk_experts;
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, num_groups);
+    ggml_set_op_params_i32(result, 1, num_top_groups);
+    ggml_set_op_params_i32(result, 2, nk);
+
+    result->op   = GGML_OP_GROUPED_TOPK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_top_k

 struct ggml_tensor * ggml_top_k(

@@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh(
     }
 }

+static void ggml_compute_forward_grouped_topk(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                iqk_grouped_top_k(dst, params->ith, params->nth);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16(

@@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             {
                 ggml_compute_forward_argsort_thresh(params, tensor);
             } break;
+        case GGML_OP_GROUPED_TOPK:
+            {
+                ggml_compute_forward_grouped_topk(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);

@@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_OP_GROUPED_TOPK:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_LEAKY_RELU:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented

@@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_ARGSORT_THRESH:
+        case GGML_OP_GROUPED_TOPK:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:

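The per-thread kernel iqk_grouped_top_k is implemented outside this file and is not part of this diff. For intuition about what the op computes, here is a hypothetical scalar reference of grouped routing as used by BailingMoeV2-style models; scoring each group by the sum of its nk best experts is an assumption inferred from that architecture, and grouped_topk_ref is an illustrative helper, not the committed code:

    #include <stdlib.h>
    #include <string.h>

    // Hypothetical scalar reference for grouped expert routing.
    // scores: [n_experts] router scores for one token.
    static int cmp_desc_f(const void * a, const void * b) {
        float x = *(const float *)a, y = *(const float *)b;
        return (x < y) - (x > y); // sort descending
    }

    static void grouped_topk_ref(const float * scores, int n_experts,
                                 int num_groups, int num_top_groups, int nk,
                                 int topk, int * out_ids) {
        int n_per_group = n_experts / num_groups;

        // 1. score each group by the sum of its nk best experts (assumed)
        float * gscore = malloc(num_groups * sizeof(float));
        float * tmp    = malloc(n_per_group * sizeof(float));
        for (int g = 0; g < num_groups; ++g) {
            memcpy(tmp, scores + g*n_per_group, n_per_group*sizeof(float));
            qsort(tmp, n_per_group, sizeof(float), cmp_desc_f);
            gscore[g] = 0;
            for (int j = 0; j < nk; ++j) gscore[g] += tmp[j];
        }

        // 2. keep only the num_top_groups best groups
        char * keep = calloc(num_groups, 1);
        for (int t = 0; t < num_top_groups; ++t) {
            int best = -1;
            for (int g = 0; g < num_groups; ++g)
                if (!keep[g] && (best < 0 || gscore[g] > gscore[best])) best = g;
            keep[best] = 1;
        }

        // 3. plain top-k over the experts in the surviving groups
        char * used = calloc(n_experts, 1);
        for (int t = 0; t < topk; ++t) {
            int best = -1;
            for (int i = 0; i < n_experts; ++i) {
                if (used[i] || !keep[i / n_per_group]) continue;
                if (best < 0 || scores[i] > scores[best]) best = i;
            }
            out_ids[t] = best;
            used[best] = 1;
        }

        free(gscore); free(tmp); free(keep); free(used);
    }

Restricting the final top-k to a handful of high-scoring groups bounds the selection work and avoids a full sort over all expert scores on CPU, which is presumably why the commit pairs this op with the argsort improvements mentioned in the commit message.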