
Commit 1350201

Merge branch 'ikawrakow:main' into main
2 parents: 8906d11 + dbfd151

13 files changed: 306 additions, 43 deletions

common/common.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.fused_moe_up_gate = true;
         return true;
     }
+    if (arg == "-ger" || arg == "--grouped-expert-routing") {
+        params.grouped_expert_routing = true;
+        return true;
+    }
     if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
         params.fused_up_gate = false;
         return true;
@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch });
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction", "experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts });
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.mla_attn          = params.mla_attn;
     cparams.attn_max_batch    = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+    cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate     = params.fused_up_gate;
     cparams.min_experts       = params.min_experts;
     cparams.thresh_experts    = params.thresh_experts;
@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
     fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+    fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "ser: %d,%g # default: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
```

common/common.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -235,6 +235,7 @@ struct gpt_params {
     int   attn_max_batch    = 0;     // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool  fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
     bool  fused_up_gate     = true;  // fused up*unary(gate) op
+    bool  grouped_expert_routing = false; // whether to use grouped expert routing (BailingMoeV2 arch)
     int   min_experts       = -1;
     float thresh_experts    = 0;
```

examples/llama-bench/llama-bench.cpp

Lines changed: 35 additions & 4 deletions
```diff
@@ -261,6 +261,7 @@ struct cmd_params {
     bool warmup;
     bool repack  = false;
     bool fmoe    = false;
+    bool ger     = false; // ger = Grouped Expert Routing
     bool no_fug  = false;
     bool use_thp = false;
     output_formats output_format;
@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
     /* verbose              */ false,
     /* warmup               */ true,
     /* repack               */ false,
-    /* use_thp              */ false,
     /* fmoe                 */ false,
+    /* ger                  */ false,
     /* no_fug               */ false,
+    /* use_thp              */ false,
     /* output_format        */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };
@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -thp, --transparent-huge-pages <0|1>  (default: %s)\n", cmd_params_defaults.use_thp ? "1" : "0");
     printf("  -ot, --override-tensor pattern        (default: none)\n");
     printf("  -fmoe, --fused-moe <0|1>              (default: %s)\n", cmd_params_defaults.fmoe ? "1" : "0");
+    printf("  -ger, --grouped-expert-routing <0|1>  (default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
     printf("  -no-fug, --no-fused-up-gate <0|1>     (default: %s)\n", cmd_params_defaults.no_fug ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.fmoe = std::stoi(argv[i]);
+        } else if (arg == "-ger" || arg == "--grouped-expert-routing") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ger = std::stoi(argv[i]);
         } else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -829,6 +838,7 @@ struct cmd_params_instance {
     bool embeddings;
     bool repack  = false;
     bool fmoe    = false;
+    bool ger     = false;
     bool no_fug  = false;
     bool use_thp = false;
     const llama_model_tensor_buft_override* buft_overrides;
@@ -876,6 +886,7 @@ struct cmd_params_instance {
         cparams.mla_attn          = mla_attn;
         cparams.attn_max_batch    = attn_max_batch;
         cparams.fused_moe_up_gate = fmoe;
+        cparams.grouped_expert_routing = ger;
         cparams.fused_up_gate     = !no_fug;
         cparams.min_experts       = ser.first;
         cparams.thresh_experts    = ser.second;
@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .embeddings = */ embd,
         /* .repack     = */ params.repack,
         /* .fmoe       = */ params.fmoe,
+        /* .ger        = */ params.ger,
         /* .no_fug     = */ params.no_fug,
         /* .use_thp    = */ params.use_thp,
         /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .embeddings = */ embd,
         /* .repack     = */ params.repack,
         /* .fmoe       = */ params.fmoe,
+        /* .ger        = */ params.ger,
         /* .no_fug     = */ params.no_fug,
         /* .use_thp    = */ params.use_thp,
         /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .embeddings = */ embd,
         /* .repack     = */ params.repack,
         /* .fmoe       = */ params.fmoe,
+        /* .ger        = */ params.ger,
         /* .no_fug     = */ params.no_fug,
         /* .use_thp    = */ params.use_thp,
         /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .embeddings = */ embd,
         /* .repack     = */ params.repack,
         /* .fmoe       = */ params.fmoe,
+        /* .ger        = */ params.ger,
         /* .no_fug     = */ params.no_fug,
         /* .use_thp    = */ params.use_thp,
         /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1086,6 +1101,7 @@ struct test {
     bool embeddings;
     bool repack  = false;
     bool fmoe    = false;
+    bool ger     = false;
     bool no_fug  = false;
     bool use_thp = false;
     int n_prompt;
@@ -1120,6 +1136,8 @@ struct test {
         use_mmap   = inst.use_mmap;
         embeddings = inst.embeddings;
         repack     = inst.repack;
+        fmoe       = inst.fmoe;
+        ger        = inst.ger;
         no_fug     = inst.no_fug;
         use_thp    = inst.use_thp;
         n_prompt   = inst.n_prompt;
@@ -1212,7 +1230,7 @@ struct test {
         "n_threads", "type_k", "type_v",
         "n_gpu_layers", "split_mode",
         "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
-        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp",
+        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
         "avg_ts", "stddev_ts", "test",
@@ -1234,7 +1252,7 @@ struct test {
         if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" || field == "f16_kv" || field == "no_kv_offload" ||
             field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
-            field == "fused_moe" || field == "fused_up_gate") {
+            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -1277,7 +1295,8 @@ struct test {
             std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
             std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
-            std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp),
+            std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
+            std::to_string(no_fug), std::to_string(use_thp),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
             std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return 4;
         }
+        if (field == "grouped_er") {
+            return 3;
+        }
         if (field == "fused_up_gate") {
             return 6;
         }
@@ -1513,6 +1535,9 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return "fmoe";
         }
+        if (field == "grouped_er") {
+            return "ger";
+        }
         if (field == "fused_up_gate") {
             return "no-fug";
         }
@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
         if (params.fmoe != cmd_params_defaults.fmoe) {
             fields.emplace_back("fused_moe");
         }
+        if (params.ger != cmd_params_defaults.ger) {
+            fields.emplace_back("grouped_er");
+        }
         if (params.no_fug != cmd_params_defaults.no_fug) {
             fields.emplace_back("fused_up_gate");
         }
```
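The `struct test` hunks above have to stay positionally in sync: the field-name list, the BOOL classifier, and the stringified values are parallel sequences paired by index, which is also why the commit adds `fmoe = inst.fmoe;` alongside `ger`. A self-contained illustration of that invariant (not the actual llama-bench code):

```cpp
#include <cassert>
#include <string>
#include <vector>

// names()[i] labels values()[i], so a field like "grouped_er" must be
// inserted at the same position in both lists, or every later column shifts.
struct bench_row {
    bool fmoe = false, ger = false, no_fug = false;

    static std::vector<std::string> names() {
        return { "fused_moe", "grouped_er", "fused_up_gate" };
    }
    std::vector<std::string> values() const {
        return { std::to_string(fmoe), std::to_string(ger), std::to_string(!no_fug) };
    }
};

int main() {
    bench_row r;
    assert(bench_row::names().size() == r.values().size()); // invariant the diff preserves
    return 0;
}
```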

ggml/include/ggml.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -650,6 +650,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_ARGSORT_THRESH,
+        GGML_OP_GROUPED_TOPK,
         GGML_OP_LEAKY_RELU,
         GGML_OP_SOFTCAP,
         GGML_OP_SOFT_CAP_MAX,
@@ -2265,6 +2266,13 @@ extern "C" {
             int                   k,
             int                   min_entries,
             float                 thresh);
+    GGML_API struct ggml_tensor * ggml_grouped_topk(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   num_groups,
+            int                   num_top_groups,
+            int                   nk,
+            int                   topk_experts);

 #define GGML_KQ_MASK_PAD 16
```
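`ggml_grouped_topk` exposes the group-limited expert routing used by BailingMoeV2-style MoE models. A plain-C++ reference sketch of the assumed semantics, inferred from the parameter names rather than from the kernel itself: score each of `num_groups` groups by the sum of its top `nk` expert scores, keep the `num_top_groups` best groups, then pick `topk_experts` experts from the survivors.

```cpp
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

// Reference sketch of grouped top-k routing (DeepSeek-V3 / BailingMoeV2 style).
// Returns the indices of the selected experts; assumes scores.size() is an
// exact multiple of num_groups.
std::vector<int> grouped_topk_ref(const std::vector<float> & scores,
                                  int num_groups, int num_top_groups,
                                  int nk, int topk_experts) {
    const int group_size = (int) scores.size() / num_groups;

    // 1. Score each group by the sum of its top-nk expert scores.
    std::vector<std::pair<float, int>> groups(num_groups);
    for (int g = 0; g < num_groups; ++g) {
        std::vector<float> s(scores.begin() + g * group_size,
                             scores.begin() + (g + 1) * group_size);
        std::partial_sort(s.begin(), s.begin() + nk, s.end(), std::greater<float>());
        float sum = 0;
        for (int j = 0; j < nk; ++j) sum += s[j];
        groups[g] = { sum, g };
    }

    // 2. Keep only the num_top_groups highest-scoring groups.
    std::partial_sort(groups.begin(), groups.begin() + num_top_groups,
                      groups.end(), std::greater<>());

    // 3. Take the overall top-k experts from the surviving groups.
    std::vector<std::pair<float, int>> cand;
    for (int i = 0; i < num_top_groups; ++i) {
        const int g = groups[i].second;
        for (int j = 0; j < group_size; ++j) {
            cand.push_back({ scores[g * group_size + j], g * group_size + j });
        }
    }
    std::partial_sort(cand.begin(), cand.begin() + topk_experts, cand.end(),
                      std::greater<>());

    std::vector<int> ids(topk_experts);
    for (int i = 0; i < topk_experts; ++i) ids[i] = cand[i].second;
    return ids;
}
```

Under this reading, a hypothetical 256-expert model with 8 groups would call `grouped_topk_ref(scores, 8, 4, 2, 8)` to pick 8 experts confined to the 4 highest-scoring groups.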

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -256,8 +256,8 @@ if (GGML_BLAS)
     endif()
 endif()

-set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp)
-set (GGML_HEADERS_IQK iqk/iqk_config.h)
+set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp iqk/iqk_cpu_ops.cpp)
+set (GGML_HEADERS_IQK iqk/iqk_config.h iqk/iqk_cpu_ops.h)
 if (GGML_IQK_MUL_MAT)
     message(STATUS "Using optimized iqk matrix multiplications")
     add_compile_definitions(GGML_USE_IQK_MULMAT)
```
